From 7400b34388326708d2925ca5a17bf7a5ce269c7a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 9 Oct 2023 11:12:36 -0700 Subject: [PATCH 1/6] Addressing unnecessary "Any" --- modules/core/src/test/scala/support/SwaggerSpecRunner.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/core/src/test/scala/support/SwaggerSpecRunner.scala b/modules/core/src/test/scala/support/SwaggerSpecRunner.scala index b5e03d45e6..ed57f2b76a 100644 --- a/modules/core/src/test/scala/support/SwaggerSpecRunner.scala +++ b/modules/core/src/test/scala/support/SwaggerSpecRunner.scala @@ -88,7 +88,9 @@ trait SwaggerSpecRunner extends EitherValues with OptionValues with TargetValues .foldLeft[(ProtocolDefinitions[L], Clients[L], Servers[L])]((ProtocolDefinitions(Nil, Nil, Nil, Nil, None), Clients(Nil, Nil), Servers(Nil, Nil))) { case ((proto, clients, servers), (generatedProto, generatedDefs)) => val newProto = - if ((proto.elems ++ proto.packageObjectContents ++ proto.packageObjectImports ++ proto.protocolImports ++ proto.implicitsObject.toList).nonEmpty) { + if ( + proto.elems.nonEmpty && proto.packageObjectContents.nonEmpty && proto.packageObjectImports.nonEmpty && proto.protocolImports.nonEmpty && proto.implicitsObject.nonEmpty + ) { proto } else { generatedProto From ffc555f80d05373f6840bc671ad497676234898a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 5 Oct 2023 14:56:50 -0700 Subject: [PATCH 2/6] Undo formatting snafu in docstring --- .../main/scala/dev/guardrail/core/Tracker.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/modules/core/src/main/scala/dev/guardrail/core/Tracker.scala b/modules/core/src/main/scala/dev/guardrail/core/Tracker.scala index 514d5648c6..66e7ec1a76 100644 --- a/modules/core/src/main/scala/dev/guardrail/core/Tracker.scala +++ b/modules/core/src/main/scala/dev/guardrail/core/Tracker.scala @@ -14,13 +14,18 @@ import cats.Functor * Tracker heavily utilizes syntax from IndexedFunctor and IndexedDistributive. These classes are similar to Functor and Traversable, except they expose the * index into the structure they are walking over. This is used to automatically build the history while walking structures. * - * val tracker: Tracker[ OpenAPI ] = Tracker(openAPI) val servers: Tracker[List[ Server ]] = tracker.downField("servers", _.getServers) val firstServer: - * Tracker[Option[ Server ]] = servers.map(_.headOption) val firstServerUrl: Tracker[Option[ String ]] = firstServer.flatDownField("url", _.getUrl) + * ```scala + * val tracker: Tracker[OpenAPI] = Tracker(openAPI) + * val servers: Tracker[List[Server]] = tracker.downField("servers", _.getServers) + * val firstServer: Tracker[Option[Server]] = servers.map(_.headOption) + * val firstServerUrl: Tracker[Option[String]] = firstServer.flatDownField("url", _.getUrl) * - * val trackedUrl: Tracker[Option[ URL ]] = firstServerUrl.map(_.map(new URL(_))) + * val trackedUrl: Tracker[Option[URL]] = firstServerUrl.map(_.map(new URL(_))) * - * // Examples of extracting: val firstServerUrl: Option[ URL ] = trackedUrl.unwrapTracker // Throw away history val firstServerUrl: Target[ URL ] = - * trackedUrl.raiseErrorIfEmpty("No Server URL found!") // Append history to the end of the error message + * // Examples of extracting: + * val firstServerUrl: Option[URL] = trackedUrl.unwrapTracker // Throw away history + * val firstServerUrl: Target[URL] = trackedUrl.raiseErrorIfEmpty("No Server URL found!") // Append history to the end of the error message + * ``` */ class Tracker[+A] private[core] (private[core] val get: A, private[core] val history: Vector[String]) { override def toString(): String = s"Tracker($get, $history)" From 8126577965e18c9199d72e1d8bd1e4b212f487ed Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 7 Oct 2023 22:35:42 -0700 Subject: [PATCH 3/6] Resolve GHA set-output deprecation --- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00ba5837d4..e659ff8adf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,7 +53,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} - name: Set scala version for matrix id: set-scala-versions # The greps on the following line are to ensure as much as possible that we've caught the right line - run: echo "::set-output name=scala_versions::$(sbt 'print githubMatrixSettings' | grep '^\[{' | grep 'bincompat' | tail -n 1)" + run: echo "scala_versions=$(sbt 'print githubMatrixSettings' | grep '^\[{' | grep 'bincompat' | tail -n 1)" >> $GITHUB_OUTPUT java: runs-on: ubuntu-20.04 needs: [core] diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 90d53b5f41..325f74badd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: run: | module="$(echo "$GITHUB_REF" | sed 's~^refs/tags/\(.*\)-v[0-9.]\+$~\1~')" echo "extract project: ${GITHUB_REF}, ${module}" - echo "::set-output name=module::$module" + echo "module=$module" >> $GITHUB_OUTPUT - uses: actions/checkout@v2 with: fetch-depth: 0 From 3d2948f4975d06f7ad1a5e8e7eaf586ec1deb55f Mon Sep 17 00:00:00 2001 From: Aleksei Lezhoev Date: Sun, 30 Jul 2023 11:19:22 +0200 Subject: [PATCH 4/6] Degeneralize server generator --- .../src/main/scala/dev/guardrail/Common.scala | 5 +- .../generators/ServerGenerator.scala | 120 --------- .../dev/guardrail/generators/Servers.scala | 17 ++ .../guardrail/terms/server/ServerTerms.scala | 179 +------------- .../DropwizardServerGenerator.scala | 198 +++++++++++---- .../springMvc/SpringMvcServerGenerator.scala | 194 +++++++++++---- .../akkaHttp/AkkaHttpServerGenerator.scala | 197 ++++++++++++--- .../DropwizardServerGenerator.scala | 172 ++++++++++--- .../scala/http4s/Http4sServerGenerator.scala | 228 +++++++++++++----- 9 files changed, 813 insertions(+), 497 deletions(-) delete mode 100644 modules/core/src/main/scala/dev/guardrail/generators/ServerGenerator.scala create mode 100644 modules/core/src/main/scala/dev/guardrail/generators/Servers.scala diff --git a/modules/core/src/main/scala/dev/guardrail/Common.scala b/modules/core/src/main/scala/dev/guardrail/Common.scala index 5a41f877a5..eae1104094 100644 --- a/modules/core/src/main/scala/dev/guardrail/Common.scala +++ b/modules/core/src/main/scala/dev/guardrail/Common.scala @@ -8,7 +8,7 @@ import java.nio.file.Path import java.net.URI import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ ClientGenerator, Clients, ProtocolDefinitions, ProtocolGenerator, ServerGenerator, Servers } +import dev.guardrail.generators.{ ClientGenerator, Clients, ProtocolDefinitions, ProtocolGenerator, Servers } import dev.guardrail.languages.LA import dev.guardrail.terms.client.ClientTerms import dev.guardrail.terms.framework.FrameworkTerms @@ -95,8 +95,7 @@ object Common { case CodegenTarget.Server => for { - serverMeta <- ServerGenerator - .fromSwagger[L, F](context, supportPackage, basePath, frameworkImports)(groupedRoutes)(protocolElems, securitySchemes, components) + serverMeta <- Se.fromSwagger(context, supportPackage, basePath, frameworkImports)(groupedRoutes)(protocolElems, securitySchemes, components) Servers(servers, supportDefinitions) = serverMeta frameworkImplicits <- getFrameworkImplicits() } yield CodegenDefinitions[L](List.empty, servers, supportDefinitions, frameworkImplicits) diff --git a/modules/core/src/main/scala/dev/guardrail/generators/ServerGenerator.scala b/modules/core/src/main/scala/dev/guardrail/generators/ServerGenerator.scala deleted file mode 100644 index b7158566b7..0000000000 --- a/modules/core/src/main/scala/dev/guardrail/generators/ServerGenerator.scala +++ /dev/null @@ -1,120 +0,0 @@ -package dev.guardrail.generators - -import cats.data.NonEmptyList -import cats.syntax.all._ - -import dev.guardrail._ -import dev.guardrail.core.SupportDefinition -import dev.guardrail.languages.LA -import dev.guardrail.terms.Responses -import dev.guardrail.terms.framework.FrameworkTerms -import dev.guardrail.terms.protocol.StrictProtocolElems -import dev.guardrail.terms.server.{ GenerateRouteMeta, SecurityExposure, ServerTerms } -import dev.guardrail.terms.{ CollectionsLibTerms, LanguageTerms, RouteMeta, SecurityScheme, SwaggerTerms } -import dev.guardrail.core.Tracker -import io.swagger.v3.oas.models.Components - -case class Servers[L <: LA](servers: List[Server[L]], supportDefinitions: List[SupportDefinition[L]]) -case class Server[L <: LA](pkg: List[String], extraImports: List[L#Import], handlerDefinition: L#Definition, serverDefinitions: List[L#Definition]) -case class CustomExtractionField[L <: LA](param: LanguageParameter[L], term: L#Term) -case class TracingField[L <: LA](param: LanguageParameter[L], term: L#Term) -case class RenderedRoutes[L <: LA]( - routes: List[L#Statement], - classAnnotations: List[L#Annotation], - methodSigs: List[L#MethodDeclaration], - supportDefinitions: List[L#Definition], - handlerDefinitions: List[L#Statement], - securitySchemesDefinitions: List[L#Definition] -) - -object ServerGenerator { - def fromSwagger[L <: LA, F[_]](context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[L#Import])( - groupedRoutes: List[(List[String], List[RouteMeta])] - )( - protocolElems: List[StrictProtocolElems[L]], - securitySchemes: Map[String, SecurityScheme[L]], - components: Tracker[Option[Components]] - )(implicit Fw: FrameworkTerms[L, F], Sc: LanguageTerms[L, F], Cl: CollectionsLibTerms[L, F], S: ServerTerms[L, F], Sw: SwaggerTerms[L, F]): F[Servers[L]] = { - import S._ - import Sw._ - import Sc._ - - for { - extraImports <- getExtraImports(context.tracing, supportPackage) - supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) - servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => - val routes = unsortedRoutes - .groupBy(_.path.unwrapTracker.indexOf('{')) - .view - .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) - .toList - .sortBy(_._1) - .flatMap(_._2) - for { - resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) - handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) - - responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => - for { - operationId <- getOperationId(operation) - responses <- Responses.getResponses(operationId, operation, protocolElems, components) - responseClsName <- formatTypeName(operationId, Some("Response")) - responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) - methodName <- formatMethodName(operationId) - parameters <- route.getParameters[L, F](components, protocolElems) - customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) - tracingField <- buildTracingFields(operation, className, context.tracing) - } yield ( - responseDefinitions, - GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) - ) - } - (responseDefinitions, serverOperations) = responseServerPair.unzip - securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { - case Nil => SecurityExposure.Undefined - case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required - } - renderedRoutes <- generateRoutes( - context.tracing, - resourceName, - handlerName, - basePath, - serverOperations, - protocolElems, - securitySchemes, - securityExposure, - context.authImplementation - ) - handlerSrc <- renderHandler( - handlerName, - renderedRoutes.methodSigs, - renderedRoutes.handlerDefinitions, - responseDefinitions.flatten, - context.customExtraction, - context.authImplementation, - securityExposure - ) - extraRouteParams <- getExtraRouteParams( - resourceName, - context.customExtraction, - context.tracing, - context.authImplementation, - securityExposure - ) - classSrc <- renderClass( - resourceName, - handlerName, - renderedRoutes.classAnnotations, - renderedRoutes.routes, - extraRouteParams, - responseDefinitions.flatten, - renderedRoutes.supportDefinitions, - renderedRoutes.securitySchemesDefinitions, - context.customExtraction, - context.authImplementation - ) - } yield Server(className, frameworkImports ++ extraImports, handlerSrc, classSrc) - } - } yield Servers[L](servers, supportDefinitions) - } -} diff --git a/modules/core/src/main/scala/dev/guardrail/generators/Servers.scala b/modules/core/src/main/scala/dev/guardrail/generators/Servers.scala new file mode 100644 index 0000000000..1981d65358 --- /dev/null +++ b/modules/core/src/main/scala/dev/guardrail/generators/Servers.scala @@ -0,0 +1,17 @@ +package dev.guardrail.generators + +import dev.guardrail.core.SupportDefinition +import dev.guardrail.languages.LA + +case class Servers[L <: LA](servers: List[Server[L]], supportDefinitions: List[SupportDefinition[L]]) +case class Server[L <: LA](pkg: List[String], extraImports: List[L#Import], handlerDefinition: L#Definition, serverDefinitions: List[L#Definition]) +case class CustomExtractionField[L <: LA](param: LanguageParameter[L], term: L#Term) +case class TracingField[L <: LA](param: LanguageParameter[L], term: L#Term) +case class RenderedRoutes[L <: LA]( + routes: List[L#Statement], + classAnnotations: List[L#Annotation], + methodSigs: List[L#MethodDeclaration], + supportDefinitions: List[L#Definition], + handlerDefinitions: List[L#Statement], + securitySchemesDefinitions: List[L#Definition] +) diff --git a/modules/core/src/main/scala/dev/guardrail/terms/server/ServerTerms.scala b/modules/core/src/main/scala/dev/guardrail/terms/server/ServerTerms.scala index f225711606..80134045a2 100644 --- a/modules/core/src/main/scala/dev/guardrail/terms/server/ServerTerms.scala +++ b/modules/core/src/main/scala/dev/guardrail/terms/server/ServerTerms.scala @@ -1,16 +1,17 @@ package dev.guardrail.terms.server -import cats.Monad import cats.data.NonEmptyList -import io.swagger.v3.oas.models.Operation +import cats.Monad -import dev.guardrail.AuthImplementation -import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ CustomExtractionField, LanguageParameters, RenderedRoutes, TracingField } +import dev.guardrail._ import dev.guardrail.languages.LA import dev.guardrail.terms.Responses +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.StrictProtocolElems -import dev.guardrail.terms.{ RouteMeta, SecurityScheme } +import dev.guardrail.terms.{ CollectionsLibTerms, LanguageTerms, RouteMeta, SecurityScheme, SwaggerTerms } +import dev.guardrail.core.Tracker +import io.swagger.v3.oas.models.Components +import dev.guardrail.generators._ case class GenerateRouteMeta[L <: LA]( operationId: String, @@ -32,166 +33,12 @@ object SecurityExposure { abstract class ServerTerms[L <: LA, F[_]] { self => def MonadF: Monad[F] - def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean): F[Option[CustomExtractionField[L]]] - def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean): F[Option[TracingField[L]]] - def generateRoutes( - tracing: Boolean, - resourceName: String, - handlerName: String, - basePath: Option[String], - routes: List[GenerateRouteMeta[L]], + + def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[L#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( protocolElems: List[StrictProtocolElems[L]], securitySchemes: Map[String, SecurityScheme[L]], - securityExposure: SecurityExposure, - authImplementation: AuthImplementation - ): F[RenderedRoutes[L]] - def getExtraRouteParams( - resourceName: String, - customExtraction: Boolean, - tracing: Boolean, - authImplementation: AuthImplementation, - securityExposure: SecurityExposure - ): F[List[L#MethodParameter]] - def generateResponseDefinitions(responseClsName: String, responses: Responses[L], protocolElems: List[StrictProtocolElems[L]]): F[List[L#Definition]] - def generateSupportDefinitions(tracing: Boolean, securitySchemes: Map[String, SecurityScheme[L]]): F[List[SupportDefinition[L]]] - def renderClass( - resourceName: String, - handlerName: String, - annotations: List[L#Annotation], - combinedRouteTerms: List[L#Statement], - extraRouteParams: List[L#MethodParameter], - responseDefinitions: List[L#Definition], - supportDefinitions: List[L#Definition], - securitySchemesDefinitions: List[L#Definition], - customExtraction: Boolean, - authImplementation: AuthImplementation - ): F[List[L#Definition]] - def renderHandler( - handlerName: String, - methodSigs: List[L#MethodDeclaration], - handlerDefinitions: List[L#Statement], - responseDefinitions: List[L#Definition], - customExtraction: Boolean, - authImplementation: AuthImplementation, - securityExposure: SecurityExposure - ): F[L#Definition] - def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]): F[List[L#Import]] - - def copy( - MonadF: Monad[F] = self.MonadF, - buildCustomExtractionFields: (Tracker[Operation], List[String], Boolean) => F[Option[CustomExtractionField[L]]] = self.buildCustomExtractionFields _, - buildTracingFields: (Tracker[Operation], List[String], Boolean) => F[Option[TracingField[L]]] = self.buildTracingFields _, - generateRoutes: ( - Boolean, - String, - String, - Option[String], - List[GenerateRouteMeta[L]], - List[StrictProtocolElems[L]], - Map[String, SecurityScheme[L]], - SecurityExposure, - AuthImplementation - ) => F[RenderedRoutes[L]] = self.generateRoutes _, - getExtraRouteParams: (String, Boolean, Boolean, AuthImplementation, SecurityExposure) => F[List[L#MethodParameter]] = self.getExtraRouteParams _, - generateResponseDefinitions: (String, Responses[L], List[StrictProtocolElems[L]]) => F[List[L#Definition]] = self.generateResponseDefinitions _, - generateSupportDefinitions: (Boolean, Map[String, SecurityScheme[L]]) => F[List[SupportDefinition[L]]] = self.generateSupportDefinitions _, - renderClass: ( - String, - String, - List[L#Annotation], - List[L#Statement], - List[L#MethodParameter], - List[L#Definition], - List[L#Definition], - List[L#Definition], - Boolean, - AuthImplementation - ) => F[List[L#Definition]] = self.renderClass _, - renderHandler: ( - String, - List[L#MethodDeclaration], - List[L#Statement], - List[L#Definition], - Boolean, - AuthImplementation, - SecurityExposure - ) => F[L#Definition] = self.renderHandler _, - getExtraImports: (Boolean, NonEmptyList[String]) => F[List[L#Import]] = self.getExtraImports _ - ) = { - val newMonadF = MonadF - val newBuildCustomExtractionFields = buildCustomExtractionFields - val newBuildTracingFields = buildTracingFields - val newGenerateRoutes = generateRoutes - val newGetExtraRouteParams = getExtraRouteParams - val newGenerateResponseDefinitions = generateResponseDefinitions - val newGenerateSupportDefinitions = generateSupportDefinitions - val newRenderClass = renderClass - val newRenderHandler = renderHandler - val newGetExtraImports = getExtraImports - - new ServerTerms[L, F] { - def MonadF = newMonadF - def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = - newBuildCustomExtractionFields(operation, resourceName, customExtraction) - def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = - newBuildTracingFields(operation, resourceName, tracing) - def generateRoutes( - tracing: Boolean, - resourceName: String, - handlerName: String, - basePath: Option[String], - routes: List[GenerateRouteMeta[L]], - protocolElems: List[StrictProtocolElems[L]], - securitySchemes: Map[String, SecurityScheme[L]], - securityExposure: SecurityExposure, - authImplementation: AuthImplementation - ) = newGenerateRoutes(tracing, resourceName, handlerName, basePath, routes, protocolElems, securitySchemes, securityExposure, authImplementation) - def getExtraRouteParams( - resourceName: String, - customExtraction: Boolean, - tracing: Boolean, - authImplementation: AuthImplementation, - securityExposure: SecurityExposure - ) = - newGetExtraRouteParams(resourceName, customExtraction, tracing, authImplementation, securityExposure) - def generateResponseDefinitions(responseClsName: String, responses: Responses[L], protocolElems: List[StrictProtocolElems[L]]) = - newGenerateResponseDefinitions(responseClsName, responses, protocolElems) - def generateSupportDefinitions(tracing: Boolean, securitySchemes: Map[String, SecurityScheme[L]]) = - newGenerateSupportDefinitions(tracing, securitySchemes) - def renderClass( - resourceName: String, - handlerName: String, - annotations: List[L#Annotation], - combinedRouteTerms: List[L#Statement], - extraRouteParams: List[L#MethodParameter], - responseDefinitions: List[L#Definition], - supportDefinitions: List[L#Definition], - securitySchemesDefinitions: List[L#Definition], - customExtraction: Boolean, - authImplementation: AuthImplementation - ): F[List[L#Definition]] = - newRenderClass( - resourceName, - handlerName, - annotations, - combinedRouteTerms, - extraRouteParams, - responseDefinitions, - supportDefinitions, - securitySchemesDefinitions, - customExtraction, - authImplementation - ) - def renderHandler( - handlerName: String, - methodSigs: List[L#MethodDeclaration], - handlerDefinitions: List[L#Statement], - responseDefinitions: List[L#Definition], - customExtraction: Boolean, - authImplementation: AuthImplementation, - securityExposure: SecurityExposure - ) = newRenderHandler(handlerName, methodSigs, handlerDefinitions, responseDefinitions, customExtraction, authImplementation, securityExposure) - def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = newGetExtraImports(tracing, supportPackage) - } - } + components: Tracker[Option[Components]] + )(implicit Fw: FrameworkTerms[L, F], Sc: LanguageTerms[L, F], Cl: CollectionsLibTerms[L, F], Sw: SwaggerTerms[L, F]): F[Servers[L]] } diff --git a/modules/java-dropwizard/src/main/scala/dev/guardrail/generators/java/dropwizard/DropwizardServerGenerator.scala b/modules/java-dropwizard/src/main/scala/dev/guardrail/generators/java/dropwizard/DropwizardServerGenerator.scala index 556a758039..60fbd2fddd 100644 --- a/modules/java-dropwizard/src/main/scala/dev/guardrail/generators/java/dropwizard/DropwizardServerGenerator.scala +++ b/modules/java-dropwizard/src/main/scala/dev/guardrail/generators/java/dropwizard/DropwizardServerGenerator.scala @@ -4,51 +4,70 @@ import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.ImportDeclaration import com.github.javaparser.ast.Modifier.Keyword._ import com.github.javaparser.ast.Modifier._ -import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type, UnknownType, VoidType } +import com.github.javaparser.ast.Node +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.`type`.ClassOrInterfaceType +import com.github.javaparser.ast.`type`.PrimitiveType +import com.github.javaparser.ast.`type`.Type +import com.github.javaparser.ast.`type`.UnknownType +import com.github.javaparser.ast.`type`.VoidType import com.github.javaparser.ast.body._ -import com.github.javaparser.ast.expr.{ MethodCallExpr, _ } +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr._ import com.github.javaparser.ast.stmt._ -import com.github.javaparser.ast.{ ImportDeclaration, Node, NodeList } -import io.swagger.v3.oas.models.Operation -import scala.compat.java8.OptionConverters._ -import scala.concurrent.Future -import scala.language.existentials -import scala.reflect.runtime.universe.typeTag - import dev.guardrail.AuthImplementation +import dev.guardrail.Context import dev.guardrail.Target -import dev.guardrail.core.{ SupportDefinition, Tracker } +import dev.guardrail.core.SupportDefinition +import dev.guardrail.core.Tracker import dev.guardrail.core.extract.ServerRawResponse +import dev.guardrail.generators.CustomExtractionField import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.RenderedRoutes +import dev.guardrail.generators.Server +import dev.guardrail.generators.Servers +import dev.guardrail.generators.TracingField import dev.guardrail.generators.java.JavaCollectionsGenerator import dev.guardrail.generators.java.JavaLanguage import dev.guardrail.generators.java.JavaVavrCollectionsGenerator import dev.guardrail.generators.java.SerializationHelpers import dev.guardrail.generators.java.syntax._ -import dev.guardrail.generators.spi.{ CollectionsGeneratorLoader, ModuleLoadResult, ServerGeneratorLoader } -import dev.guardrail.generators.{ CustomExtractionField, RenderedRoutes, TracingField } +import dev.guardrail.generators.spi.CollectionsGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ServerGeneratorLoader import dev.guardrail.javaext.helpers.ResponseHelpers -import dev.guardrail.shims.OperationExt -import dev.guardrail.terms.collections.{ CollectionsAbstraction, JavaStdLibCollections, JavaVavrCollections } +import dev.guardrail.shims._ +import dev.guardrail.terms.ApplicationJson +import dev.guardrail.terms.BinaryContent +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.MultipartFormData +import dev.guardrail.terms.OctetStream +import dev.guardrail.terms.Response +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.TextContent +import dev.guardrail.terms.TextPlain +import dev.guardrail.terms.UrlencodedFormData +import dev.guardrail.terms.collections.CollectionsAbstraction +import dev.guardrail.terms.collections.JavaStdLibCollections +import dev.guardrail.terms.collections.JavaVavrCollections +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.StrictProtocolElems import dev.guardrail.terms.server._ -import dev.guardrail.terms.{ - ApplicationJson, - BinaryContent, - CollectionsLibTerms, - ContentType, - MultipartFormData, - OctetStream, - Response, - Responses, - RouteMeta, - SecurityScheme, - TextContent, - TextPlain, - UrlencodedFormData -} +import io.swagger.v3.oas.models.Components +import io.swagger.v3.oas.models.Operation + +import scala.compat.java8.OptionConverters._ +import scala.concurrent.Future +import scala.language.existentials +import scala.reflect.runtime.universe.typeTag class DropwizardServerGeneratorLoader extends ServerGeneratorLoader { type L = JavaLanguage @@ -69,9 +88,103 @@ object DropwizardServerGenerator { @SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements", "org.wartremover.warts.Null")) class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, Target], Ca: CollectionsAbstraction[JavaLanguage]) extends ServerTerms[JavaLanguage, Target] { - import Ca._ - implicit def MonadF: Monad[Target] = Target.targetInstances + override implicit def MonadF: Monad[Target] = Target.targetInstances + + override def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[JavaLanguage#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[JavaLanguage]], + securitySchemes: Map[String, SecurityScheme[JavaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[Servers[JavaLanguage]] = { + import Sw._ + import Sc._ + import dev.guardrail._ + + for { + extraImports <- getExtraImports(context.tracing, supportPackage) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes + .groupBy(_.path.unwrapTracker.indexOf('{')) + .view + .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) + .toList + .sortBy(_._1) + .flatMap(_._2) + for { + resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) + handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) + + responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses(operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + methodName <- formatMethodName(operationId) + parameters <- route.getParameters[JavaLanguage, Target](components, protocolElems) + customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) + tracingField <- buildTracingFields(operation, className, context.tracing) + } yield ( + responseDefinitions, + GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) + ) + } + (responseDefinitions, serverOperations) = responseServerPair.unzip + securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { + case Nil => SecurityExposure.Undefined + case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required + } + renderedRoutes <- generateRoutes( + context.tracing, + resourceName, + handlerName, + basePath, + serverOperations, + protocolElems, + securitySchemes, + securityExposure, + context.authImplementation + ) + handlerSrc <- renderHandler( + handlerName, + renderedRoutes.methodSigs, + renderedRoutes.handlerDefinitions, + responseDefinitions.flatten, + context.customExtraction, + context.authImplementation, + securityExposure + ) + extraRouteParams <- getExtraRouteParams( + resourceName, + context.customExtraction, + context.tracing, + context.authImplementation, + securityExposure + ) + classSrc <- renderClass( + resourceName, + handlerName, + renderedRoutes.classAnnotations, + renderedRoutes.routes, + extraRouteParams, + responseDefinitions.flatten, + renderedRoutes.supportDefinitions, + renderedRoutes.securitySchemesDefinitions, + context.customExtraction, + context.authImplementation + ) + } yield Server[JavaLanguage](className, frameworkImports ++ extraImports, handlerSrc, classSrc) + } + } yield Servers[JavaLanguage](servers, supportDefinitions) + } @SuppressWarnings(Array("org.wartremover.warts.TripleQuestionMark")) private def toJaxRsAnnotationName: ContentType => Expression = { @@ -234,7 +347,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa } } - override def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]): Target[List[ImportDeclaration]] = + private def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]): Target[List[ImportDeclaration]] = List( "javax.inject.Inject", "javax.validation.constraints.NotNull", @@ -261,7 +374,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa "org.slf4j.LoggerFactory" ).traverse(safeParseRawImport) - override def buildCustomExtractionFields( + private def buildCustomExtractionFields( operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean @@ -272,14 +385,14 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa Target.pure(Option.empty) } - override def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean): Target[Option[TracingField[JavaLanguage]]] = + private def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean): Target[Option[TracingField[JavaLanguage]]] = if (tracing) { Target.raiseUserError(s"Tracing is not yet supported by this framework") } else { Target.pure(Option.empty) } - override def generateRoutes( + private def generateRoutes( tracing: Boolean, resourceName: String, handlerName: String, @@ -289,7 +402,9 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa securitySchemes: Map[String, SecurityScheme[JavaLanguage]], securityExposure: SecurityExposure, authImplementation: AuthImplementation - ): Target[RenderedRoutes[JavaLanguage]] = + ): Target[RenderedRoutes[JavaLanguage]] = { + import Ca._ + for { resourceType <- safeParseClassOrInterfaceType(resourceName) handlerType <- safeParseClassOrInterfaceType(handlerName) @@ -653,8 +768,9 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa RenderedRoutes[JavaLanguage](routeMethods, annotations, handlerMethodSigs, supportDefinitions, List.empty, List.empty) } + } - override def getExtraRouteParams( + private def getExtraRouteParams( resourceName: String, customExtraction: Boolean, tracing: Boolean, @@ -673,7 +789,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa } else Target.pure(List.empty) } yield customExtraction ::: tracing - override def generateResponseDefinitions( + private def generateResponseDefinitions( responseClsName: String, responses: Responses[JavaLanguage], protocolElems: List[StrictProtocolElems[JavaLanguage]] @@ -692,7 +808,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa abstractResponseClass :: Nil } - override def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[JavaLanguage]] ): Target[List[SupportDefinition[JavaLanguage]]] = @@ -726,7 +842,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa ) } - override def renderClass( + private def renderClass( className: String, handlerName: String, classAnnotations: List[com.github.javaparser.ast.expr.AnnotationExpr], @@ -742,7 +858,7 @@ class DropwizardServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLa safeParseSimpleName(handlerName) >> Target.pure(doRenderClass(className, classAnnotations, supportDefinitions, combinedRouteTerms) :: Nil) - override def renderHandler( + private def renderHandler( handlerName: String, methodSigs: List[com.github.javaparser.ast.body.MethodDeclaration], handlerDefinitions: List[com.github.javaparser.ast.Node], diff --git a/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcServerGenerator.scala b/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcServerGenerator.scala index a79bed69b3..4c8dc60be7 100644 --- a/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcServerGenerator.scala +++ b/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcServerGenerator.scala @@ -6,50 +6,67 @@ import cats.syntax.all._ import com.github.javaparser.StaticJavaParser import com.github.javaparser.ast.Modifier.Keyword._ import com.github.javaparser.ast.Modifier._ -import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type } +import com.github.javaparser.ast.Node +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.`type`.ClassOrInterfaceType +import com.github.javaparser.ast.`type`.PrimitiveType +import com.github.javaparser.ast.`type`.Type import com.github.javaparser.ast.body._ import com.github.javaparser.ast.expr._ import com.github.javaparser.ast.stmt._ -import com.github.javaparser.ast.{ Node, NodeList } -import io.swagger.v3.oas.models.Operation -import io.swagger.v3.oas.models.responses.ApiResponse -import scala.compat.java8.OptionConverters._ -import scala.language.existentials -import scala.reflect.runtime.universe.typeTag -import scala.util.Try - import dev.guardrail.AuthImplementation +import dev.guardrail.Context +import dev.guardrail.Target import dev.guardrail.core.Tracker import dev.guardrail.core.extract.ServerRawResponse +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.LanguageParameters import dev.guardrail.generators.RenderedRoutes +import dev.guardrail.generators.Server +import dev.guardrail.generators.Servers import dev.guardrail.generators.java.JavaCollectionsGenerator import dev.guardrail.generators.java.JavaLanguage import dev.guardrail.generators.java.JavaVavrCollectionsGenerator import dev.guardrail.generators.java.SerializationHelpers import dev.guardrail.generators.java.syntax._ -import dev.guardrail.generators.spi.{ CollectionsGeneratorLoader, ModuleLoadResult, ServerGeneratorLoader } -import dev.guardrail.generators.{ LanguageParameter, LanguageParameters } +import dev.guardrail.generators.spi.CollectionsGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ServerGeneratorLoader +import dev.guardrail.shims._ +import dev.guardrail.terms.AnyContentType +import dev.guardrail.terms.ApplicationJson +import dev.guardrail.terms.BinaryContent +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.MultipartFormData +import dev.guardrail.terms.OctetStream +import dev.guardrail.terms.Response +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.TextContent +import dev.guardrail.terms.TextPlain +import dev.guardrail.terms.UrlencodedFormData +import dev.guardrail.terms.collections.CollectionsAbstraction +import dev.guardrail.terms.collections.JavaStdLibCollections +import dev.guardrail.terms.collections.JavaVavrCollections +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.ADT +import dev.guardrail.terms.protocol.ClassDefinition +import dev.guardrail.terms.protocol.EnumDefinition +import dev.guardrail.terms.protocol.RandomType +import dev.guardrail.terms.protocol.StrictProtocolElems import dev.guardrail.terms.server._ -import dev.guardrail.shims.OperationExt -import dev.guardrail.terms.collections.{ CollectionsAbstraction, JavaStdLibCollections, JavaVavrCollections } -import dev.guardrail.terms.{ - AnyContentType, - ApplicationJson, - BinaryContent, - CollectionsLibTerms, - ContentType, - MultipartFormData, - OctetStream, - Response, - Responses, - RouteMeta, - SecurityScheme, - TextContent, - TextPlain, - UrlencodedFormData -} -import dev.guardrail.Target -import dev.guardrail.terms.protocol.{ ADT, ClassDefinition, EnumDefinition, RandomType, StrictProtocolElems } +import io.swagger.v3.oas.models.Components +import io.swagger.v3.oas.models.Operation +import io.swagger.v3.oas.models.responses.ApiResponse + +import scala.compat.java8.OptionConverters._ +import scala.language.existentials +import scala.reflect.runtime.universe.typeTag +import scala.util.Try class SpringMvcServerGeneratorLoader extends ServerGeneratorLoader { type L = JavaLanguage @@ -70,10 +87,104 @@ object SpringMvcServerGenerator { @SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements", "org.wartremover.warts.Null")) class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, Target], Ca: CollectionsAbstraction[JavaLanguage]) extends ServerTerms[JavaLanguage, Target] { - import Ca._ override implicit def MonadF: Monad[Target] = Target.targetInstances + override def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[JavaLanguage#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[JavaLanguage]], + securitySchemes: Map[String, SecurityScheme[JavaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[Servers[JavaLanguage]] = { + import Sw._ + import Sc._ + import dev.guardrail._ + + for { + extraImports <- getExtraImports(context.tracing, supportPackage) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes + .groupBy(_.path.unwrapTracker.indexOf('{')) + .view + .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) + .toList + .sortBy(_._1) + .flatMap(_._2) + for { + resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) + handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) + + responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses(operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + methodName <- formatMethodName(operationId) + parameters <- route.getParameters[JavaLanguage, Target](components, protocolElems) + customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) + tracingField <- buildTracingFields(operation, className, context.tracing) + } yield ( + responseDefinitions, + GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) + ) + } + (responseDefinitions, serverOperations) = responseServerPair.unzip + securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { + case Nil => SecurityExposure.Undefined + case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required + } + renderedRoutes <- generateRoutes( + context.tracing, + resourceName, + handlerName, + basePath, + serverOperations, + protocolElems, + securitySchemes, + securityExposure, + context.authImplementation + ) + handlerSrc <- renderHandler( + handlerName, + renderedRoutes.methodSigs, + renderedRoutes.handlerDefinitions, + responseDefinitions.flatten, + context.customExtraction, + context.authImplementation, + securityExposure + ) + extraRouteParams <- getExtraRouteParams( + resourceName, + context.customExtraction, + context.tracing, + context.authImplementation, + securityExposure + ) + classSrc <- renderClass( + resourceName, + handlerName, + renderedRoutes.classAnnotations, + renderedRoutes.routes, + extraRouteParams, + responseDefinitions.flatten, + renderedRoutes.supportDefinitions, + renderedRoutes.securitySchemesDefinitions, + context.customExtraction, + context.authImplementation + ) + } yield Server[JavaLanguage](className, frameworkImports ++ extraImports, handlerSrc, classSrc) + } + } yield Servers[JavaLanguage](servers, supportDefinitions) + } + @SuppressWarnings(Array("org.wartremover.warts.TripleQuestionMark")) private def toSpringMediaType: ContentType => Expression = { case _: ApplicationJson => new FieldAccessExpr(new NameExpr("MediaType"), "APPLICATION_JSON_VALUE") @@ -308,7 +419,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan } } - override def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = + private def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = List( "java.util.Optional", "java.util.concurrent.CompletionStage", @@ -325,21 +436,21 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan "org.springframework.web.multipart.MultipartFile" ).traverse(safeParseRawImport) - override def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = + private def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = if (customExtraction) { Target.raiseUserError(s"Custom Extraction is not yet supported by this framework") } else { Target.pure(Option.empty) } - override def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = + private def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = if (tracing) { Target.raiseUserError(s"Tracing is not yet supported by this framework") } else { Target.pure(Option.empty) } - override def generateRoutes( + private def generateRoutes( tracing: Boolean, resourceName: String, handlerName: String, @@ -354,6 +465,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan resourceType <- safeParseClassOrInterfaceType(resourceName) handlerType <- safeParseClassOrInterfaceType(handlerName) } yield { + import Ca._ val basePathComponents = basePath.toList.flatMap(splitPathComponents) val commonPathPrefix = findPathPrefix(routes.map(_.routeMeta.path.unwrapTracker)) @@ -718,7 +830,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan RenderedRoutes[JavaLanguage](routeMethods, annotations, handlerMethodSigs, supportDefinitions, List.empty, List.empty) } - override def getExtraRouteParams( + private def getExtraRouteParams( resourceName: String, customExtraction: Boolean, tracing: Boolean, @@ -737,7 +849,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan } else Target.pure(List.empty) } yield customExtraction ::: tracing - override def generateResponseDefinitions( + private def generateResponseDefinitions( responseClsName: String, responses: Responses[JavaLanguage], protocolElems: List[StrictProtocolElems[JavaLanguage]] @@ -755,7 +867,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan abstractResponseClass :: Nil } - override def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[JavaLanguage]] ) = @@ -763,7 +875,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan shower <- SerializationHelpers.showerSupportDef } yield List(shower) - override def renderClass( + private def renderClass( className: String, handlerName: String, classAnnotations: List[com.github.javaparser.ast.expr.AnnotationExpr], @@ -779,7 +891,7 @@ class SpringMvcServerGenerator private (implicit Cl: CollectionsLibTerms[JavaLan safeParseSimpleName(handlerName) >> Target.pure(doRenderClass(className, classAnnotations, supportDefinitions, combinedRouteTerms) :: Nil) - override def renderHandler( + private def renderHandler( handlerName: String, methodSigs: List[com.github.javaparser.ast.body.MethodDeclaration], handlerDefinitions: List[com.github.javaparser.ast.Node], diff --git a/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpServerGenerator.scala b/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpServerGenerator.scala index d0f6a34d01..9c586441aa 100644 --- a/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpServerGenerator.scala +++ b/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpServerGenerator.scala @@ -1,27 +1,59 @@ package dev.guardrail.generators.scala.akkaHttp +import _root_.io.swagger.v3.oas.models.Components import _root_.io.swagger.v3.oas.models.Operation import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod import cats.Monad import cats.data.NonEmptyList import cats.implicits._ -import scala.meta._ -import scala.reflect.runtime.universe.typeTag -import dev.guardrail.AuthImplementation -import dev.guardrail.core.extract.{ ServerRawResponse, TracingLabel } -import dev.guardrail.core.{ LiteralRawType, MapRawType, ReifiedRawType, Tracker, VectorRawType } +import dev.guardrail._ +import dev.guardrail.core.LiteralRawType +import dev.guardrail.core.MapRawType +import dev.guardrail.core.ReifiedRawType +import dev.guardrail.core.Tracker +import dev.guardrail.core.VectorRawType +import dev.guardrail.core.extract.ServerRawResponse +import dev.guardrail.core.extract.TracingLabel +import dev.guardrail.generators.CustomExtractionField +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.RawParameterName +import dev.guardrail.generators.RenderedRoutes +import dev.guardrail.generators.Server +import dev.guardrail.generators.Servers +import dev.guardrail.generators.TracingField import dev.guardrail.generators.operations.TracingLabelFormatter -import dev.guardrail.generators.scala.{ CirceModelGenerator, CirceRefinedModelGenerator, JacksonModelGenerator, ModelGeneratorType, ScalaLanguage } +import dev.guardrail.generators.scala.CirceModelGenerator +import dev.guardrail.generators.scala.CirceRefinedModelGenerator +import dev.guardrail.generators.scala.JacksonModelGenerator +import dev.guardrail.generators.scala.ModelGeneratorType +import dev.guardrail.generators.scala.ScalaLanguage import dev.guardrail.generators.scala.syntax._ -import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader, ServerGeneratorLoader } +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ProtocolGeneratorLoader +import dev.guardrail.generators.spi.ServerGeneratorLoader import dev.guardrail.generators.syntax._ -import dev.guardrail.generators.{ CustomExtractionField, LanguageParameter, RawParameterName, RenderedRoutes, TracingField } import dev.guardrail.shims._ +import dev.guardrail.terms.ApplicationJson +import dev.guardrail.terms.BinaryContent +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.Header +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.MultipartFormData +import dev.guardrail.terms.Response +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.TextContent +import dev.guardrail.terms.TextPlain +import dev.guardrail.terms.UrlencodedFormData +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol._ import dev.guardrail.terms.server._ -import dev.guardrail.terms.{ ApplicationJson, BinaryContent, ContentType, Header, MultipartFormData, Response } -import dev.guardrail.terms.{ Responses, RouteMeta, SecurityScheme, TextContent, TextPlain, UrlencodedFormData } -import dev.guardrail.{ Target, UserError } + +import scala.meta._ +import scala.reflect.runtime.universe.typeTag class AkkaHttpServerGeneratorLoader extends ServerGeneratorLoader { type L = ScalaLanguage @@ -84,13 +116,108 @@ object AkkaHttpServerGenerator { class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGeneratorType: ModelGeneratorType) extends ServerTerms[ScalaLanguage, Target] { val customExtractionTypeName: Type.Name = Type.Name("E") - def splitOperationParts(operationId: String): (List[String], String) = { + implicit def MonadF: Monad[Target] = Target.targetInstances + + override def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[ScalaLanguage#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], + securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Servers[ScalaLanguage]] = { + import Sw._ + import Sc._ + + for { + extraImports <- getExtraImports(context.tracing, supportPackage) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes + .groupBy(_.path.unwrapTracker.indexOf('{')) + .view + .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) + .toList + .sortBy(_._1) + .flatMap(_._2) + for { + resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) + handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) + + responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses(operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + methodName <- formatMethodName(operationId) + parameters <- route.getParameters[ScalaLanguage, Target](components, protocolElems) + customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) + tracingField <- buildTracingFields(operation, className, context.tracing) + } yield ( + responseDefinitions, + GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) + ) + } + (responseDefinitions, serverOperations) = responseServerPair.unzip + securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { + case Nil => SecurityExposure.Undefined + case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required + } + renderedRoutes <- generateRoutes( + context.tracing, + resourceName, + handlerName, + basePath, + serverOperations, + protocolElems, + securitySchemes, + securityExposure, + context.authImplementation + ) + handlerSrc <- renderHandler( + handlerName, + renderedRoutes.methodSigs, + renderedRoutes.handlerDefinitions, + responseDefinitions.flatten, + context.customExtraction, + context.authImplementation, + securityExposure + ) + extraRouteParams <- getExtraRouteParams( + resourceName, + context.customExtraction, + context.tracing, + context.authImplementation, + securityExposure + ) + classSrc <- renderClass( + resourceName, + handlerName, + renderedRoutes.classAnnotations, + renderedRoutes.routes, + extraRouteParams, + responseDefinitions.flatten, + renderedRoutes.supportDefinitions, + renderedRoutes.securitySchemesDefinitions, + context.customExtraction, + context.authImplementation + ) + } yield Server[ScalaLanguage](className, frameworkImports ++ extraImports, handlerSrc, classSrc) + } + } yield Servers[ScalaLanguage](servers, supportDefinitions) + } + + private def splitOperationParts(operationId: String): (List[String], String) = { val parts = operationId.split('.') (parts.drop(1).toList, parts.last) } - implicit def MonadF: Monad[Target] = Target.targetInstances - def generateResponseDefinitions( + private def generateResponseDefinitions( responseClsName: String, responses: Responses[ScalaLanguage], protocolElems: List[StrictProtocolElems[ScalaLanguage]] @@ -192,7 +319,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe companion ) - def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = + private def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = for { _ <- Target.log.debug(s"buildCustomExtractionFields(${operation.unwrapTracker.showNotNull}, ${resourceName}, ${customExtraction})") res <- @@ -212,7 +339,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } else Target.pure(None) } yield res - def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = + private def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = for { _ <- Target.log.debug(s"buildTracingFields(${operation.unwrapTracker.showNotNull}, ${resourceName}, ${tracing})") res <- @@ -231,7 +358,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } yield Some(TracingField[ScalaLanguage](LanguageParameter.fromParam(param"traceBuilder: TraceBuilder"), q"""trace(${label})""")) } else Target.pure(None) } yield res - def generateRoutes( + private def generateRoutes( tracing: Boolean, resourceName: String, handlerName: String, @@ -255,7 +382,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe renderedRoutes.flatMap(_.handlerDefinitions), List.empty ) - def renderHandler( + private def renderHandler( handlerName: String, methodSigs: List[scala.meta.Decl.Def], handlerDefinitions: List[scala.meta.Stat], @@ -274,7 +401,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } """ } - def getExtraRouteParams( + private def getExtraRouteParams( resourceName: String, customExtraction: Boolean, tracing: Boolean, @@ -292,7 +419,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe Target.pure(List(param"""trace: String => Directive1[TraceBuilder]""")) } else Target.pure(List.empty) } yield extractParam ::: traceParam - def renderClass( + private def renderClass( resourceName: String, handlerName: String, annotations: List[scala.meta.Mod.Annot], @@ -330,7 +457,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } """) } - def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = + private def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = for { _ <- Target.log.debug(s"getExtraImports(${tracing})") } yield List( @@ -338,13 +465,13 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe Option(q"import scala.language.higherKinds") ).flatten - def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[ScalaLanguage]] ) = Target.pure(List.empty) - def httpMethodToAkka(method: HttpMethod): Target[Term] = method match { + private def httpMethodToAkka(method: HttpMethod): Target[Term] = method match { case HttpMethod.DELETE => Target.pure(q"delete") case HttpMethod.GET => Target.pure(q"get") case HttpMethod.PATCH => Target.pure(q"patch") @@ -355,7 +482,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe case other => Target.raiseUserError(s"Unknown method: ${other}") } - def pathStrToAkka( + private def pathStrToAkka( basePath: Option[String], path: Tracker[String], pathArgs: List[LanguageParameter[ScalaLanguage]] @@ -373,13 +500,13 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } } - def findInnerTpe(rawType: ReifiedRawType): LiteralRawType = rawType match { + private def findInnerTpe(rawType: ReifiedRawType): LiteralRawType = rawType match { case x: LiteralRawType => x case VectorRawType(inner) => findInnerTpe(inner) case MapRawType(inner) => findInnerTpe(inner) } - def directivesFromParams( + private def directivesFromParams( required: Term => Type => Option[Term] => Target[Term], multi: Term => Type => Option[Term] => (Term => Term) => Target[Term], multiOpt: Term => Type => Option[Term] => (Term => Term) => Target[Term], @@ -429,14 +556,14 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe Some((xs.foldLeft[Term](x) { case (a, n) => q"${a} & ${n}" }, params.map(_.paramName))) } - def bodyToAkka(methodName: String, body: Option[LanguageParameter[ScalaLanguage]]): Target[Option[Term]] = + private def bodyToAkka(methodName: String, body: Option[LanguageParameter[ScalaLanguage]]): Target[Option[Term]] = Target.pure( body.map { case LanguageParameter(_, _, _, _, argType) => q"entity(as[${argType}](${Term.Name(s"${methodName}Decoder")}))" } ) - def headersToAkka: List[LanguageParameter[ScalaLanguage]] => Target[Option[(Term, List[Term.Name])]] = + private def headersToAkka: List[LanguageParameter[ScalaLanguage]] => Target[Option[(Term, List[Term.Name])]] = directivesFromParams( arg => { case t"String" => @@ -481,7 +608,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } ) _ - def qsToAkka: List[LanguageParameter[ScalaLanguage]] => Target[Option[(Term, List[Term.Name])]] = { + private def qsToAkka: List[LanguageParameter[ScalaLanguage]] => Target[Option[(Term, List[Term.Name])]] = { type Unmarshaller = Term type Arg = Term val nameReceptacle: Arg => Type => Term = arg => tpe => q"Symbol(${arg}).as[${tpe}]" @@ -504,7 +631,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe override def toString(): String = s"Binding($value)" } - def formToAkka(consumes: Tracker[NonEmptyList[ContentType]], methodName: String)( + private def formToAkka(consumes: Tracker[NonEmptyList[ContentType]], methodName: String)( params: List[LanguageParameter[ScalaLanguage]] ): Target[(Option[Term], List[Stat])] = Target.log.function("formToAkka") { for { @@ -862,7 +989,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe } yield result.getOrElse((Option.empty[Term], List.empty[Stat])) } - case class RenderedRoute(route: Term, methodSig: Decl.Def, supportDefinitions: List[Defn], handlerDefinitions: List[Stat]) + private case class RenderedRoute(route: Term, methodSig: Decl.Def, supportDefinitions: List[Defn], handlerDefinitions: List[Stat]) private def generateRoute( resourceName: String, @@ -977,7 +1104,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe ) } - def generateHeaderParams(headers: List[Header[ScalaLanguage]], prefix: Term.Name): Term = { + private def generateHeaderParams(headers: List[Header[ScalaLanguage]], prefix: Term.Name): Term = { def liftOptionTerm(tParamName: Term.Name, tName: RawParameterName) = q"$prefix.$tParamName.map(v => RawHeader(${tName.toLit}, Formatter.show(v)))" @@ -993,13 +1120,13 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe q"scala.collection.immutable.Seq[Option[HttpHeader]](..$args).flatten" } - def combineRouteTerms(terms: List[Term]): Target[Term] = + private def combineRouteTerms(terms: List[Term]): Target[Term] = Target.log.function(s"combineRouteTerms(<${terms.length} routes>)")(for { routes <- Target.fromOption(NonEmptyList.fromList(terms), UserError("Generated no routes, no source to generate")) _ <- routes.traverse(route => Target.log.debug(route.toString)) } yield routes.tail.foldLeft(routes.head) { case (a, n) => q"${a} ~ ${n}" }) - def generateCodecs( + private def generateCodecs( methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], responses: Responses[ScalaLanguage], @@ -1007,7 +1134,7 @@ class AkkaHttpServerGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGe ): Target[List[Defn.Def]] = generateDecoders(methodName, bodyArgs, consumes) - def generateDecoders( + private def generateDecoders( methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], consumes: Tracker[NonEmptyList[ContentType]] diff --git a/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardServerGenerator.scala b/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardServerGenerator.scala index 09e90653c3..afb7b7c8c1 100644 --- a/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardServerGenerator.scala +++ b/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardServerGenerator.scala @@ -1,36 +1,46 @@ package dev.guardrail.generators.scala.dropwizard +import _root_.io.swagger.v3.oas.models.Components import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ -import io.swagger.v3.oas.models.Operation -import scala.meta._ -import scala.reflect.runtime.universe.typeTag - -import dev.guardrail.AuthImplementation -import dev.guardrail.Target -import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ CustomExtractionField, LanguageParameter, RawParameterName, RenderedRoutes, TracingField } +import dev.guardrail._ +import dev.guardrail.core.SupportDefinition +import dev.guardrail.core.Tracker +import dev.guardrail.generators.CustomExtractionField +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.RawParameterName +import dev.guardrail.generators.RenderedRoutes +import dev.guardrail.generators.Server +import dev.guardrail.generators.Servers +import dev.guardrail.generators.TracingField import dev.guardrail.generators.scala.ScalaLanguage -import dev.guardrail.generators.spi.{ ModuleLoadResult, ServerGeneratorLoader } +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ServerGeneratorLoader import dev.guardrail.scalaext.helpers.ResponseHelpers -import dev.guardrail.shims.OperationExt +import dev.guardrail.shims._ +import dev.guardrail.terms.AnyContentType +import dev.guardrail.terms.ApplicationJson +import dev.guardrail.terms.BinaryContent +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.MultipartFormData +import dev.guardrail.terms.OctetStream +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.TextContent +import dev.guardrail.terms.TextPlain +import dev.guardrail.terms.UrlencodedFormData +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.StrictProtocolElems -import dev.guardrail.terms.server.{ GenerateRouteMeta, SecurityExposure, ServerTerms } -import dev.guardrail.terms.{ - AnyContentType, - ApplicationJson, - BinaryContent, - ContentType, - MultipartFormData, - OctetStream, - Responses, - RouteMeta, - SecurityScheme, - TextContent, - TextPlain, - UrlencodedFormData -} +import dev.guardrail.terms.server._ +import io.swagger.v3.oas.models.Operation + +import scala.meta._ +import scala.reflect.runtime.universe.typeTag class DropwizardServerGeneratorLoader extends ServerGeneratorLoader { type L = ScalaLanguage @@ -46,6 +56,100 @@ object DropwizardServerGenerator { class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Target] { override def MonadF: Monad[Target] = Target.targetInstances + override def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[ScalaLanguage#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], + securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Servers[ScalaLanguage]] = { + import Sw._ + import Sc._ + + for { + extraImports <- getExtraImports(context.tracing, supportPackage) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes + .groupBy(_.path.unwrapTracker.indexOf('{')) + .view + .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) + .toList + .sortBy(_._1) + .flatMap(_._2) + for { + resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) + handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) + + responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses(operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + methodName <- formatMethodName(operationId) + parameters <- route.getParameters[ScalaLanguage, Target](components, protocolElems) + customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) + tracingField <- buildTracingFields(operation, className, context.tracing) + } yield ( + responseDefinitions, + GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) + ) + } + (responseDefinitions, serverOperations) = responseServerPair.unzip + securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { + case Nil => SecurityExposure.Undefined + case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required + } + renderedRoutes <- generateRoutes( + context.tracing, + resourceName, + handlerName, + basePath, + serverOperations, + protocolElems, + securitySchemes, + securityExposure, + context.authImplementation + ) + handlerSrc <- renderHandler( + handlerName, + renderedRoutes.methodSigs, + renderedRoutes.handlerDefinitions, + responseDefinitions.flatten, + context.customExtraction, + context.authImplementation, + securityExposure + ) + extraRouteParams <- getExtraRouteParams( + resourceName, + context.customExtraction, + context.tracing, + context.authImplementation, + securityExposure + ) + classSrc <- renderClass( + resourceName, + handlerName, + renderedRoutes.classAnnotations, + renderedRoutes.routes, + extraRouteParams, + responseDefinitions.flatten, + renderedRoutes.supportDefinitions, + renderedRoutes.securitySchemesDefinitions, + context.customExtraction, + context.authImplementation + ) + } yield Server[ScalaLanguage](className, frameworkImports ++ extraImports, handlerSrc, classSrc) + } + } yield Servers[ScalaLanguage](servers, supportDefinitions) + } + private val buildTermSelect: NonEmptyList[String] => Term.Ref = { case NonEmptyList(head, tail) => tail.map(Term.Name.apply _).foldLeft[Term.Ref](Term.Name(head))(Term.Select.apply _) } @@ -161,7 +265,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe buildTransformers(param, httpParameterAnnotation).foldLeft(param.param)((accum, next) => next(accum)) } - override def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]): Target[List[Import]] = + private def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]): Target[List[Import]] = Target.pure( List( q"import io.dropwizard.jersey.PATCH", @@ -180,7 +284,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe ) ) - override def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[ScalaLanguage]] ): Target[List[SupportDefinition[ScalaLanguage]]] = @@ -253,7 +357,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe ) ) - override def buildCustomExtractionFields( + private def buildCustomExtractionFields( operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean @@ -264,10 +368,10 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe Target.pure(Option.empty) } - override def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean): Target[Option[TracingField[ScalaLanguage]]] = + private def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean): Target[Option[TracingField[ScalaLanguage]]] = Target.pure(None) - override def generateResponseDefinitions( + private def generateResponseDefinitions( responseClsName: String, responses: Responses[ScalaLanguage], protocolElems: List[StrictProtocolElems[ScalaLanguage]] @@ -301,7 +405,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe ) } - override def generateRoutes( + private def generateRoutes( tracing: Boolean, resourceName: String, handlerName: String, @@ -437,7 +541,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe ) } - override def getExtraRouteParams( + private def getExtraRouteParams( resourceName: String, customExtraction: Boolean, tracing: Boolean, @@ -456,7 +560,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe } else Target.pure(List.empty) } yield customExtraction ::: tracing - override def renderClass( + private def renderClass( resourceName: String, handlerName: String, annotations: List[Mod.Annot], @@ -487,7 +591,7 @@ class DropwizardServerGenerator private extends ServerTerms[ScalaLanguage, Targe ) } - override def renderHandler( + private def renderHandler( handlerName: String, methodSigs: List[Decl.Def], handlerDefinitions: List[Stat], diff --git a/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sServerGenerator.scala b/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sServerGenerator.scala index 33834b2109..043a00478b 100644 --- a/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sServerGenerator.scala +++ b/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sServerGenerator.scala @@ -1,35 +1,54 @@ package dev.guardrail.generators.scala.http4s +import _root_.io.swagger.v3.oas.models.Components +import _root_.io.swagger.v3.oas.models.Operation +import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod import cats.Monad +import cats.Traverse import cats.data.NonEmptyList import cats.implicits._ -import cats.Traverse -import scala.meta._ -import scala.reflect.runtime.universe.typeTag - -import dev.guardrail.{ AuthImplementation, Target, UserError } -import dev.guardrail.terms.protocol.StrictProtocolElems +import dev.guardrail.AuthImplementation +import dev.guardrail.AuthImplementation.Custom +import dev.guardrail.AuthImplementation.Disable +import dev.guardrail.AuthImplementation.Native +import dev.guardrail.AuthImplementation.Simple +import dev.guardrail._ import dev.guardrail.core.Tracker -import dev.guardrail.core.extract.{ ServerRawResponse, TracingLabel } -import dev.guardrail.generators.{ CustomExtractionField, LanguageParameter, LanguageParameters, RenderedRoutes, TracingField } -import dev.guardrail.generators.syntax._ +import dev.guardrail.core.extract.ServerRawResponse +import dev.guardrail.core.extract.TracingLabel +import dev.guardrail.generators.CustomExtractionField +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.LanguageParameters +import dev.guardrail.generators.RenderedRoutes +import dev.guardrail.generators.Server +import dev.guardrail.generators.Servers +import dev.guardrail.generators.TracingField import dev.guardrail.generators.operations.TracingLabelFormatter -import dev.guardrail.generators.scala.syntax._ -import dev.guardrail.generators.scala.{ CirceModelGenerator, ModelGeneratorType, ResponseADTHelper } +import dev.guardrail.generators.scala.CirceModelGenerator +import dev.guardrail.generators.scala.ModelGeneratorType +import dev.guardrail.generators.scala.ResponseADTHelper import dev.guardrail.generators.scala.ScalaLanguage -import dev.guardrail.generators.spi.{ ModuleLoadResult, ServerGeneratorLoader } -import dev.guardrail.terms.{ ContentType, Header, Response, Responses } -import dev.guardrail.terms.server._ +import dev.guardrail.generators.scala.syntax._ +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ServerGeneratorLoader +import dev.guardrail.generators.syntax._ import dev.guardrail.shims._ -import dev.guardrail.terms.{ RouteMeta, SecurityScheme } - -import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod -import _root_.io.swagger.v3.oas.models.Operation +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.Header +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.Response +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta import dev.guardrail.terms.SecurityRequirements -import dev.guardrail.AuthImplementation.Disable -import dev.guardrail.AuthImplementation.Native -import dev.guardrail.AuthImplementation.Simple -import dev.guardrail.AuthImplementation.Custom +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StrictProtocolElems +import dev.guardrail.terms.server._ + +import scala.meta._ +import scala.reflect.runtime.universe.typeTag class Http4sServerGeneratorLoader extends ServerGeneratorLoader { type L = ScalaLanguage @@ -60,31 +79,126 @@ object Http4sServerGenerator { } class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms[ScalaLanguage, Target] { - val customExtractionTypeName: Type.Name = Type.Name("E") val authContextTypeName: Type.Name = Type.Name("AuthContext") val authErrorTypeName: Type.Name = Type.Name("AuthError") val authSchemesTypeName: Type.Name = Type.Name("AuthSchemes") val authRequirementTypeName: Type.Name = Type.Name("AuthRequirement") + override def fromSwagger(context: Context, supportPackage: NonEmptyList[String], basePath: Option[String], frameworkImports: List[ScalaLanguage#Import])( + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], + securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Servers[ScalaLanguage]] = { + import Sw._ + import Sc._ + + for { + extraImports <- getExtraImports(context.tracing, supportPackage) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + servers <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes + .groupBy(_.path.unwrapTracker.indexOf('{')) + .view + .mapValues(_.sortBy(r => (r.path.unwrapTracker, r.method))) + .toList + .sortBy(_._1) + .flatMap(_._2) + for { + resourceName <- formatTypeName(className.lastOption.getOrElse(""), Some("Resource")) + handlerName <- formatTypeName(className.lastOption.getOrElse(""), Some("Handler")) + + responseServerPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses(operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + methodName <- formatMethodName(operationId) + parameters <- route.getParameters[ScalaLanguage, Target](components, protocolElems) + customExtractionField <- buildCustomExtractionFields(operation, className, context.customExtraction) + tracingField <- buildTracingFields(operation, className, context.tracing) + } yield ( + responseDefinitions, + GenerateRouteMeta(operationId, methodName, responseClsName, customExtractionField, tracingField, route, parameters, responses) + ) + } + (responseDefinitions, serverOperations) = responseServerPair.unzip + securityExposure = serverOperations.flatMap(_.routeMeta.securityRequirements) match { + case Nil => SecurityExposure.Undefined + case xs => if (xs.exists(_.optional)) SecurityExposure.Optional else SecurityExposure.Required + } + renderedRoutes <- generateRoutes( + context.tracing, + resourceName, + handlerName, + basePath, + serverOperations, + protocolElems, + securitySchemes, + securityExposure, + context.authImplementation + ) + handlerSrc <- renderHandler( + handlerName, + renderedRoutes.methodSigs, + renderedRoutes.handlerDefinitions, + responseDefinitions.flatten, + context.customExtraction, + context.authImplementation, + securityExposure + ) + extraRouteParams <- getExtraRouteParams( + resourceName, + context.customExtraction, + context.tracing, + context.authImplementation, + securityExposure + ) + classSrc <- renderClass( + resourceName, + handlerName, + renderedRoutes.classAnnotations, + renderedRoutes.routes, + extraRouteParams, + responseDefinitions.flatten, + renderedRoutes.supportDefinitions, + renderedRoutes.securitySchemesDefinitions, + context.customExtraction, + context.authImplementation + ) + } yield Server[ScalaLanguage](className, frameworkImports ++ extraImports, handlerSrc, classSrc) + } + } yield Servers[ScalaLanguage](servers, supportDefinitions) + } + private val bodyUtf8Decode = version match { case Http4sVersion.V0_22 => q"utf8Decode" case Http4sVersion.V0_23 => q"utf8.decode" } - def splitOperationParts(operationId: String): (List[String], String) = { + private def splitOperationParts(operationId: String): (List[String], String) = { val parts = operationId.split('.') (parts.drop(1).toList, parts.last) } + implicit def MonadF: Monad[Target] = Target.targetInstances - def generateResponseDefinitions( + + private def generateResponseDefinitions( responseClsName: String, responses: Responses[ScalaLanguage], protocolElems: List[StrictProtocolElems[ScalaLanguage]] ) = Target.pure(ResponseADTHelper.generateResponseDefinitions(responseClsName, responses, protocolElems)) - def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = + private def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) = for { _ <- Target.log.debug(s"buildCustomExtractionFields(${operation.unwrapTracker.showNotNull}, ${resourceName}, ${customExtraction})") res <- @@ -103,7 +217,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } else Target.pure(None) } yield res - def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = + private def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) = Target.log.function("buildTracingFields")(for { _ <- Target.log.debug(s"Args: ${operation}, ${resourceName}, ${tracing}") res <- @@ -123,7 +237,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } else Target.pure(None) } yield res) - override def generateRoutes( + private def generateRoutes( tracing: Boolean, resourceName: String, handlerName: String, @@ -247,7 +361,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms private def securitySchemeNameToClassName(name: String): Term.Name = Term.Name(name.toPascalCase) - override def renderHandler( + private def renderHandler( handlerName: String, methodSigs: List[scala.meta.Decl.Def], handlerDefinitions: List[scala.meta.Stat], @@ -270,7 +384,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } """) - def getExtraRouteParams( + private def getExtraRouteParams( resourceName: String, customExtraction: Boolean, tracing: Boolean, @@ -314,13 +428,13 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } } yield customExtraction_.toList ::: tracing_.toList ::: authentication_.toList ::: List(mapRoute)) - def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[ScalaLanguage]] ) = Target.pure(List.empty) - def renderClass( + private def renderClass( resourceName: String, handlerName: String, annotations: List[scala.meta.Mod.Annot], @@ -371,7 +485,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms ) ) - def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = + private def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) = Target.log.function("getExtraImports")( for { _ <- Target.log.debug(s"Args: ${tracing}") @@ -382,7 +496,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms ) ) - def httpMethodToHttp4s(method: HttpMethod): Target[Term.Name] = method match { + private def httpMethodToHttp4s(method: HttpMethod): Target[Term.Name] = method match { case HttpMethod.DELETE => Target.pure(Term.Name("DELETE")) case HttpMethod.GET => Target.pure(Term.Name("GET")) case HttpMethod.PATCH => Target.pure(Term.Name("PATCH")) @@ -392,13 +506,13 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms case other => Target.raiseUserError(s"Unknown method: ${other}") } - def pathStrToHttp4s(basePath: Option[String], path: Tracker[String], pathArgs: List[LanguageParameter[ScalaLanguage]]): Target[(Pat, Option[Pat])] = + private def pathStrToHttp4s(basePath: Option[String], path: Tracker[String], pathArgs: List[LanguageParameter[ScalaLanguage]]): Target[(Pat, Option[Pat])] = (basePath.getOrElse("") + path.unwrapTracker).stripPrefix("/") match { case "" => Target.pure((p"${Term.Name("Root")}", None)) case finalPath => Http4sServerGenerator.generateUrlPathExtractors(Tracker.cloneHistory(path, finalPath), pathArgs, CirceModelGenerator.V012) } - def directivesFromParams[T]( + private def directivesFromParams[T]( required: LanguageParameter[ScalaLanguage] => Type => Target[T], multi: LanguageParameter[ScalaLanguage] => Type => (Term => Term) => Target[T], multiOpt: LanguageParameter[ScalaLanguage] => Type => (Term => Term) => Target[T], @@ -434,7 +548,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } } yield directives - def bodyToHttp4s(methodName: String, body: Option[LanguageParameter[ScalaLanguage]]): Target[Option[Term => Term]] = + private def bodyToHttp4s(methodName: String, body: Option[LanguageParameter[ScalaLanguage]]): Target[Option[Term => Term]] = Target.pure( body.map { case LanguageParameter(_, _, paramName, _, _) => content => q""" @@ -456,9 +570,9 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } ) - case class Param(generator: Option[Enumerator.Generator], matcher: Option[(Term, Pat)], handlerCallArg: Term) + private case class Param(generator: Option[Enumerator.Generator], matcher: Option[(Term, Pat)], handlerCallArg: Term) - def headersToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = + private def headersToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = directivesFromParams( arg => { case t"String" => @@ -499,7 +613,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } ) - def qsToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[Option[Pat]] = + private def qsToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[Option[Pat]] = params => directivesFromParams( arg => _ => Target.pure(p"${Term.Name(s"${methodName.capitalize}${arg.argName.value.capitalize}Matcher")}(${Pat.Var(arg.paramName)})"), @@ -512,7 +626,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms Some(xs.foldLeft[Pat](x) { case (a, n) => p"${a} +& ${n}" }) } - def formToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = + private def formToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = directivesFromParams( arg => { case t"String" => @@ -587,7 +701,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } ) - def asyncFormToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = + private def asyncFormToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] = directivesFromParams( arg => elemType => @@ -717,7 +831,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } ) - case class RenderedRoute(methodName: String, route: Case, methodSig: Decl.Def, supportDefinitions: List[Defn], handlerDefinitions: List[Stat]) + private case class RenderedRoute(methodName: String, route: Case, methodSig: Decl.Def, supportDefinitions: List[Defn], handlerDefinitions: List[Stat]) private def generateRoute( resourceName: String, @@ -999,7 +1113,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms q"""_root_.cats.data.NonEmptyList.of(..$orElements)""" } - def createHttp4sHeaders(headers: List[Header[ScalaLanguage]]): (Term.Name, List[Defn.Val]) = { + private def createHttp4sHeaders(headers: List[Header[ScalaLanguage]]): (Term.Name, List[Defn.Val]) = { val (names, definitions) = headers.map { case Header(name, required, _, termName) => val nameLiteral = Lit.String(name) val headerName = Term.Name(s"${name.toCamelCase}Header") @@ -1012,21 +1126,21 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms (headersTerm, definitions :+ allHeaders) } - def combineRouteTerms(terms: List[Case]): Target[Term] = + private def combineRouteTerms(terms: List[Case]): Target[Term] = Target.log.function("combineRouteTerms")(for { _ <- Target.log.debug(s"Args: <${terms.length} routes>") routes <- Target.fromOption(NonEmptyList.fromList(terms), UserError("Generated no routes, no source to generate")) _ <- routes.traverse(route => Target.log.debug(route.toString)) } yield scala.meta.Term.PartialFunction(routes.toList)) - def generateSupportDefinitions(route: RouteMeta, parameters: LanguageParameters[ScalaLanguage]): Target[List[Defn]] = + private def generateSupportDefinitions(route: RouteMeta, parameters: LanguageParameters[ScalaLanguage]): Target[List[Defn]] = for { operation <- Target.pure(route.operation) pathArgs = parameters.pathParams } yield generatePathParamExtractors(pathArgs) - def generatePathParamExtractors(pathArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] = + private def generatePathParamExtractors(pathArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] = pathArgs .map(_.argType) .flatMap { @@ -1055,7 +1169,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms """ } - def modifiedOptionalMultiQueryParamDecoderMatcher(matcherName: Term.Name, container: Type, argName: Lit.String, tpe: Type, transform: Term => Term) = + private def modifiedOptionalMultiQueryParamDecoderMatcher(matcherName: Term.Name, container: Type, argName: Lit.String, tpe: Type, transform: Term => Term) = q""" object ${matcherName} { def unapply(params: Map[String, collection.Seq[String]]): Option[Option[$container[$tpe]]] = { @@ -1069,7 +1183,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } """ - def modifiedQueryParamDecoderMatcher(matcherName: Term.Name, container: Type, argName: Lit.String, tpe: Type, transform: Term => Term) = + private def modifiedQueryParamDecoderMatcher(matcherName: Term.Name, container: Type, argName: Lit.String, tpe: Type, transform: Term => Term) = q""" object ${matcherName} { def unapply(params: Map[String, collection.Seq[String]]): Option[${container}[${tpe}]] = { @@ -1082,7 +1196,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms } """ - def generateQueryParamMatchers(methodName: String, qsArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] = { + private def generateQueryParamMatchers(methodName: String, qsArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] = { val (decoders, matchers) = qsArgs .traverse { case LanguageParameter(_, param, _, argName, argType) => val containerTransformations = Map[String, Term => Term]( @@ -1141,7 +1255,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms /** It's not possible to use backticks inside pattern matching as it has different semantics: backticks inside match are just references to an already * existing bindings. */ - def prepareParameters[F[_]: Traverse](parameters: F[LanguageParameter[ScalaLanguage]]): Target[F[LanguageParameter[ScalaLanguage]]] = + private def prepareParameters[F[_]: Traverse](parameters: F[LanguageParameter[ScalaLanguage]]): Target[F[LanguageParameter[ScalaLanguage]]] = if (parameters.exists(param => param.paramName.syntax != param.paramName.value)) { // let's try to prefix them all with underscore and see if it helps for { @@ -1160,7 +1274,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms Target.pure(parameters) } - def generateCodecs( + private def generateCodecs( methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], responses: Responses[ScalaLanguage], @@ -1172,7 +1286,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms responses ) - def generateDecoders(methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], consumes: Seq[ContentType]): List[Defn.Val] = + private def generateDecoders(methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], consumes: Seq[ContentType]): List[Defn.Val] = bodyArgs.toList.flatMap { case LanguageParameter(_, _, _, _, argType) => List( q"protected[this] val ${Pat.Typed(Pat.Var(Term.Name(s"${methodName.uncapitalized}Decoder")), t"EntityDecoder[F, $argType]")} = ${ResponseADTHelper @@ -1180,7 +1294,7 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms ) } - def generateEncoders(methodName: String, responses: Responses[ScalaLanguage], produces: Seq[ContentType]): List[Defn.Val] = + private def generateEncoders(methodName: String, responses: Responses[ScalaLanguage], produces: Seq[ContentType]): List[Defn.Val] = for { response <- responses.value (_, tpe, _) <- response.value @@ -1189,21 +1303,21 @@ class Http4sServerGenerator private (version: Http4sVersion) extends ServerTerms q"protected[this] val ${Pat.Var(Term.Name(s"$methodName${response.statusCodeName}Encoder"))} = ${ResponseADTHelper.generateEncoder(tpe, contentTypes)}" } - def generateResponseGenerators(methodName: String, responses: Responses[ScalaLanguage]): List[Defn.Val] = + private def generateResponseGenerators(methodName: String, responses: Responses[ScalaLanguage]): List[Defn.Val] = for { response <- responses.value if response.value.nonEmpty } yield q"protected[this] val ${Pat.Var(Term.Name(s"$methodName${response.statusCodeName}EntityResponseGenerator"))} = ${ResponseADTHelper .generateEntityResponseGenerator(q"org.http4s.Status.${response.statusCodeName}")}" - def generateTracingExtractor(methodName: String, tracingField: Term): Defn.Object = + private def generateTracingExtractor(methodName: String, tracingField: Term): Defn.Object = q""" object ${Term.Name(s"usingFor${methodName.capitalize}")} { def unapply(r: Request[F]): Some[(Request[F], TraceBuilder[F])] = Some(r -> $tracingField(r)) } """ - def generateCustomExtractionFieldsExtractor(methodName: String, extractField: Term): Defn.Object = + private def generateCustomExtractionFieldsExtractor(methodName: String, extractField: Term): Defn.Object = q""" object ${Term.Name(s"extractorFor${methodName.capitalize}")} { def unapply(r: Request[F]): Some[(Request[F], $customExtractionTypeName)] = Some(r -> $extractField(r)) From 1e78560e7ce5a5c40e31d6ee669820a609beec5a Mon Sep 17 00:00:00 2001 From: Aleksei Lezhoev Date: Sun, 30 Jul 2023 11:55:11 +0200 Subject: [PATCH 5/6] Degeneralize client generator --- .../src/main/scala/dev/guardrail/Common.scala | 5 +- .../generators/ClientGenerator.scala | 82 --------- .../dev/guardrail/generators/Clients.scala | 21 +++ .../guardrail/terms/client/ClientTerms.scala | 136 ++------------- .../AsyncHttpClientClientGenerator.scala | 156 ++++++++++++++---- .../springMvc/SpringMvcClientGenerator.scala | 78 +++------ .../akkaHttp/AkkaHttpClientGenerator.scala | 124 +++++++++++--- .../DropwizardClientGenerator.scala | 75 ++++----- .../scala/http4s/Http4sClientGenerator.scala | 99 +++++++++-- 9 files changed, 393 insertions(+), 383 deletions(-) delete mode 100644 modules/core/src/main/scala/dev/guardrail/generators/ClientGenerator.scala create mode 100644 modules/core/src/main/scala/dev/guardrail/generators/Clients.scala diff --git a/modules/core/src/main/scala/dev/guardrail/Common.scala b/modules/core/src/main/scala/dev/guardrail/Common.scala index eae1104094..615860d9b9 100644 --- a/modules/core/src/main/scala/dev/guardrail/Common.scala +++ b/modules/core/src/main/scala/dev/guardrail/Common.scala @@ -8,7 +8,7 @@ import java.nio.file.Path import java.net.URI import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ ClientGenerator, Clients, ProtocolDefinitions, ProtocolGenerator, Servers } +import dev.guardrail.generators.{ Clients, ProtocolDefinitions, ProtocolGenerator, Servers } import dev.guardrail.languages.LA import dev.guardrail.terms.client.ClientTerms import dev.guardrail.terms.framework.FrameworkTerms @@ -87,8 +87,7 @@ object Common { codegen <- kind match { case CodegenTarget.Client => for { - clientMeta <- ClientGenerator - .fromSwagger[L, F](context, frameworkImports)(serverUrls, basePath, groupedRoutes)(protocolElems, securitySchemes, components) + clientMeta <- C.fromSwagger(context, frameworkImports)(serverUrls, basePath, groupedRoutes)(protocolElems, securitySchemes, components) Clients(clients, supportDefinitions) = clientMeta frameworkImplicits <- getFrameworkImplicits() } yield CodegenDefinitions[L](clients, List.empty, supportDefinitions, frameworkImplicits) diff --git a/modules/core/src/main/scala/dev/guardrail/generators/ClientGenerator.scala b/modules/core/src/main/scala/dev/guardrail/generators/ClientGenerator.scala deleted file mode 100644 index 3f1a997334..0000000000 --- a/modules/core/src/main/scala/dev/guardrail/generators/ClientGenerator.scala +++ /dev/null @@ -1,82 +0,0 @@ -package dev.guardrail.generators - -import cats.data.NonEmptyList -import cats.implicits._ -import java.net.URI - -import dev.guardrail.core.SupportDefinition -import dev.guardrail.languages.LA -import dev.guardrail.terms.Responses -import dev.guardrail.terms._ -import dev.guardrail.terms.client.ClientTerms -import dev.guardrail.terms.framework.FrameworkTerms -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.{ Context, _ } -import dev.guardrail.core.Tracker -import io.swagger.v3.oas.models.Components - -case class Clients[L <: LA](clients: List[Client[L]], supportDefinitions: List[SupportDefinition[L]]) -case class Client[L <: LA]( - pkg: List[String], - clientName: String, - imports: List[L#Import], - staticDefns: StaticDefns[L], - client: NonEmptyList[Either[L#Trait, L#ClassDefinition]], - responseDefinitions: List[L#Definition] -) -case class RenderedClientOperation[L <: LA]( - clientOperation: L#Definition, - supportDefinitions: List[L#Definition] -) - -object ClientGenerator { - def fromSwagger[L <: LA, F[_]](context: Context, frameworkImports: List[L#Import])( - serverUrls: Option[NonEmptyList[URI]], - basePath: Option[String], - groupedRoutes: List[(List[String], List[RouteMeta])] - )( - protocolElems: List[StrictProtocolElems[L]], - securitySchemes: Map[String, SecurityScheme[L]], - components: Tracker[Option[Components]] - )(implicit C: ClientTerms[L, F], Fw: FrameworkTerms[L, F], Sc: LanguageTerms[L, F], Cl: CollectionsLibTerms[L, F], Sw: SwaggerTerms[L, F]): F[Clients[L]] = { - import C._ - import Sc._ - import Sw._ - for { - clientImports <- getImports(context.tracing) - clientExtraImports <- getExtraImports(context.tracing) - supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) - clients <- groupedRoutes.traverse { case (className, unsortedRoutes) => - val routes = unsortedRoutes.sortBy(r => (r.path.unwrapTracker, r.method)) - for { - clientName <- formatTypeName(className.lastOption.getOrElse(""), Some("Client")) - responseClientPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => - for { - operationId <- getOperationId(operation) - responses <- Responses.getResponses[L, F](operationId, operation, protocolElems, components) - responseClsName <- formatTypeName(operationId, Some("Response")) - responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) - parameters <- route.getParameters[L, F](components, protocolElems) - methodName <- formatMethodName(operationId) - clientOp <- generateClientOperation(className, responseClsName, context.tracing, securitySchemes, parameters)(route, methodName, responses) - } yield (responseDefinitions, clientOp) - } - (responseDefinitions, clientOperations) = responseClientPair.unzip - tracingName = Option(className.mkString("-")).filterNot(_.isEmpty) - ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing) - staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing) - client <- buildClient( - clientName, - tracingName, - serverUrls, - basePath, - ctorArgs, - clientOperations.map(_.clientOperation), - clientOperations.flatMap(_.supportDefinitions), - context.tracing - ) - } yield Client[L](className, clientName, clientImports ++ frameworkImports ++ clientExtraImports, staticDefns, client, responseDefinitions.flatten) - } - } yield Clients[L](clients, supportDefinitions) - } -} diff --git a/modules/core/src/main/scala/dev/guardrail/generators/Clients.scala b/modules/core/src/main/scala/dev/guardrail/generators/Clients.scala new file mode 100644 index 0000000000..ac9d313084 --- /dev/null +++ b/modules/core/src/main/scala/dev/guardrail/generators/Clients.scala @@ -0,0 +1,21 @@ +package dev.guardrail.generators + +import cats.data.NonEmptyList + +import dev.guardrail.core.SupportDefinition +import dev.guardrail.languages.LA +import dev.guardrail.terms.protocol.StaticDefns + +case class Clients[L <: LA](clients: List[Client[L]], supportDefinitions: List[SupportDefinition[L]]) +case class Client[L <: LA]( + pkg: List[String], + clientName: String, + imports: List[L#Import], + staticDefns: StaticDefns[L], + client: NonEmptyList[Either[L#Trait, L#ClassDefinition]], + responseDefinitions: List[L#Definition] +) +case class RenderedClientOperation[L <: LA]( + clientOperation: L#Definition, + supportDefinitions: List[L#Definition] +) diff --git a/modules/core/src/main/scala/dev/guardrail/terms/client/ClientTerms.scala b/modules/core/src/main/scala/dev/guardrail/terms/client/ClientTerms.scala index a48d70578a..e7862f30b3 100644 --- a/modules/core/src/main/scala/dev/guardrail/terms/client/ClientTerms.scala +++ b/modules/core/src/main/scala/dev/guardrail/terms/client/ClientTerms.scala @@ -2,129 +2,27 @@ package dev.guardrail.terms.client import cats.Monad import cats.data.NonEmptyList -import java.net.URI - -import dev.guardrail.core.SupportDefinition -import dev.guardrail.generators.LanguageParameters -import dev.guardrail.generators.RenderedClientOperation +import dev.guardrail.Context +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Clients import dev.guardrail.languages.LA -import dev.guardrail.terms.Responses -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.terms.{ RouteMeta, SecurityScheme } +import dev.guardrail.terms._ +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StrictProtocolElems +import io.swagger.v3.oas.models.Components + +import java.net.URI abstract class ClientTerms[L <: LA, F[_]] { self => def MonadF: Monad[F] - def generateClientOperation( - className: List[String], - responseClsName: String, - tracing: Boolean, - securitySchemes: Map[String, SecurityScheme[L]], - parameters: LanguageParameters[L] - )( - route: RouteMeta, - methodName: String, - responses: Responses[L] - ): F[RenderedClientOperation[L]] - def getImports(tracing: Boolean): F[List[L#Import]] - def getExtraImports(tracing: Boolean): F[List[L#Import]] - def clientClsArgs(tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], tracing: Boolean): F[List[List[L#MethodParameter]]] - def generateResponseDefinitions(responseClsName: String, responses: Responses[L], protocolElems: List[StrictProtocolElems[L]]): F[List[L#Definition]] - def generateSupportDefinitions(tracing: Boolean, securitySchemes: Map[String, SecurityScheme[L]]): F[List[SupportDefinition[L]]] - def buildStaticDefns( - clientName: String, - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - ctorArgs: List[List[L#MethodParameter]], - tracing: Boolean - ): F[StaticDefns[L]] - def buildClient( - clientName: String, - tracingName: Option[String], + + def fromSwagger(context: Context, frameworkImports: List[L#Import])( serverUrls: Option[NonEmptyList[URI]], basePath: Option[String], - ctorArgs: List[List[L#MethodParameter]], - clientCalls: List[L#Definition], - supportDefinitions: List[L#Definition], - tracing: Boolean - ): F[NonEmptyList[Either[L#Trait, L#ClassDefinition]]] - - def copy( - MonadF: Monad[F] = self.MonadF, - generateClientOperation: (List[String], String, Boolean, Map[String, SecurityScheme[L]], LanguageParameters[L]) => ( - RouteMeta, - String, - Responses[L] - ) => F[RenderedClientOperation[L]] = self.generateClientOperation _, - getImports: Boolean => F[List[L#Import]] = self.getImports _, - getExtraImports: Boolean => F[List[L#Import]] = self.getExtraImports _, - clientClsArgs: (Option[String], Option[NonEmptyList[URI]], Boolean) => F[List[List[L#MethodParameter]]] = self.clientClsArgs _, - generateResponseDefinitions: (String, Responses[L], List[StrictProtocolElems[L]]) => F[List[L#Definition]] = self.generateResponseDefinitions _, - generateSupportDefinitions: (Boolean, Map[String, SecurityScheme[L]]) => F[List[SupportDefinition[L]]] = self.generateSupportDefinitions _, - buildStaticDefns: (String, Option[String], Option[NonEmptyList[URI]], List[List[L#MethodParameter]], Boolean) => F[StaticDefns[L]] = - self.buildStaticDefns _, - buildClient: ( - String, - Option[String], - Option[NonEmptyList[URI]], - Option[String], - List[List[L#MethodParameter]], - List[L#Definition], - List[L#Definition], - Boolean - ) => F[NonEmptyList[Either[L#Trait, L#ClassDefinition]]] = self.buildClient _ - ): ClientTerms[L, F] = { - val newMonadF = MonadF - val newGenerateClientOperation = generateClientOperation - val newGetImports = getImports - val newGetExtraImports = getExtraImports - val newClientClsArgs = clientClsArgs - val newGenerateResponseDefinitions = generateResponseDefinitions - val newGenerateSupportDefinitions = generateSupportDefinitions - val newBuildStaticDefns = buildStaticDefns - val newBuildClient = buildClient - - new ClientTerms[L, F] { - def MonadF = newMonadF - def generateClientOperation( - className: List[String], - responseClsName: String, - tracing: Boolean, - securitySchemes: Map[String, SecurityScheme[L]], - parameters: LanguageParameters[L] - )( - route: RouteMeta, - methodName: String, - responses: Responses[L] - ) = - newGenerateClientOperation(className, responseClsName, tracing, securitySchemes, parameters)(route, methodName, responses) - def getImports(tracing: Boolean) = newGetImports(tracing) - def getExtraImports(tracing: Boolean) = newGetExtraImports(tracing) - def clientClsArgs(tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], tracing: Boolean) = - newClientClsArgs(tracingName, serverUrls, tracing) - def generateResponseDefinitions(responseClsName: String, responses: Responses[L], protocolElems: List[StrictProtocolElems[L]]) = - newGenerateResponseDefinitions(responseClsName, responses, protocolElems) - def generateSupportDefinitions(tracing: Boolean, securitySchemes: Map[String, SecurityScheme[L]]): F[List[SupportDefinition[L]]] = - newGenerateSupportDefinitions(tracing, securitySchemes) - - def buildStaticDefns( - clientName: String, - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - ctorArgs: List[List[L#MethodParameter]], - tracing: Boolean - ): F[StaticDefns[L]] = - newBuildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, tracing) - def buildClient( - clientName: String, - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - basePath: Option[String], - ctorArgs: List[List[L#MethodParameter]], - clientCalls: List[L#Definition], - supportDefinitions: List[L#Definition], - tracing: Boolean - ) = - newBuildClient(clientName, tracingName, serverUrls, basePath, ctorArgs, clientCalls, supportDefinitions, tracing) - } - } + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[L]], + securitySchemes: Map[String, SecurityScheme[L]], + components: Tracker[Option[Components]] + )(implicit Fw: FrameworkTerms[L, F], Sc: LanguageTerms[L, F], Cl: CollectionsLibTerms[L, F], Sw: SwaggerTerms[L, F]): F[Clients[L]] } diff --git a/modules/java-async-http/src/main/scala/dev/guardrail/generators/java/asyncHttpClient/AsyncHttpClientClientGenerator.scala b/modules/java-async-http/src/main/scala/dev/guardrail/generators/java/asyncHttpClient/AsyncHttpClientClientGenerator.scala index d8c73297ef..7b5530f224 100644 --- a/modules/java-async-http/src/main/scala/dev/guardrail/generators/java/asyncHttpClient/AsyncHttpClientClientGenerator.scala +++ b/modules/java-async-http/src/main/scala/dev/guardrail/generators/java/asyncHttpClient/AsyncHttpClientClientGenerator.scala @@ -1,22 +1,32 @@ package dev.guardrail.generators.java.asyncHttpClient +import _root_.io.swagger.v3.oas.models.Components import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.ImportDeclaration import com.github.javaparser.ast.Modifier.Keyword._ import com.github.javaparser.ast.Modifier._ -import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, Type, UnknownType, VoidType } +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.`type`.ClassOrInterfaceType +import com.github.javaparser.ast.`type`.Type +import com.github.javaparser.ast.`type`.UnknownType +import com.github.javaparser.ast.`type`.VoidType import com.github.javaparser.ast.body._ -import com.github.javaparser.ast.expr.{ MethodCallExpr, NameExpr, _ } +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.expr._ import com.github.javaparser.ast.stmt._ -import com.github.javaparser.ast.{ ImportDeclaration, NodeList } -import java.net.URI -import java.util.concurrent.CompletionStage -import scala.annotation.tailrec - +import dev.guardrail.Context import dev.guardrail.Target -import dev.guardrail.core.{ PathExtractor, SupportDefinition } +import dev.guardrail.core.PathExtractor +import dev.guardrail.core.SupportDefinition +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Client +import dev.guardrail.generators.Clients +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.LanguageParameters import dev.guardrail.generators.RenderedClientOperation import dev.guardrail.generators.java.JavaCollectionsGenerator import dev.guardrail.generators.java.JavaGenerator @@ -24,28 +34,37 @@ import dev.guardrail.generators.java.JavaLanguage import dev.guardrail.generators.java.JavaVavrCollectionsGenerator import dev.guardrail.generators.java.asyncHttpClient.AsyncHttpClientHelpers._ import dev.guardrail.generators.java.syntax._ -import dev.guardrail.generators.spi.{ ClientGeneratorLoader, CollectionsGeneratorLoader, ModuleLoadResult } -import dev.guardrail.generators.{ LanguageParameter, LanguageParameters } +import dev.guardrail.generators.spi.ClientGeneratorLoader +import dev.guardrail.generators.spi.CollectionsGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult import dev.guardrail.javaext.helpers.ResponseHelpers -import scala.reflect.runtime.universe.typeTag import dev.guardrail.shims._ +import dev.guardrail.terms.AnyContentType +import dev.guardrail.terms.ApplicationJson +import dev.guardrail.terms.BinaryContent +import dev.guardrail.terms.CollectionsLibTerms +import dev.guardrail.terms.ContentType +import dev.guardrail.terms.LanguageTerms +import dev.guardrail.terms.OctetStream +import dev.guardrail.terms.Response +import dev.guardrail.terms.Responses +import dev.guardrail.terms.RouteMeta +import dev.guardrail.terms.SecurityScheme +import dev.guardrail.terms.SwaggerTerms +import dev.guardrail.terms.TextContent +import dev.guardrail.terms.TextPlain import dev.guardrail.terms.client.ClientTerms -import dev.guardrail.terms.collections.{ CollectionsAbstraction, JavaStdLibCollections, JavaVavrCollections } -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.terms.{ - AnyContentType, - ApplicationJson, - BinaryContent, - CollectionsLibTerms, - ContentType, - OctetStream, - Response, - Responses, - RouteMeta, - SecurityScheme, - TextContent, - TextPlain -} +import dev.guardrail.terms.collections.CollectionsAbstraction +import dev.guardrail.terms.collections.JavaStdLibCollections +import dev.guardrail.terms.collections.JavaVavrCollections +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StaticDefns +import dev.guardrail.terms.protocol.StrictProtocolElems + +import java.net.URI +import java.util.concurrent.CompletionStage +import scala.annotation.tailrec +import scala.reflect.runtime.universe.typeTag class AsyncHttpClientClientGeneratorLoader extends ClientGeneratorLoader { type L = JavaLanguage @@ -61,6 +80,76 @@ class AsyncHttpClientClientGeneratorLoader extends ClientGeneratorLoader { object AsyncHttpClientClientGenerator { def apply()(implicit Cl: CollectionsLibTerms[JavaLanguage, Target], Ca: CollectionsAbstraction[JavaLanguage]): ClientTerms[JavaLanguage, Target] = new AsyncHttpClientClientGenerator +} + +@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements", "org.wartremover.warts.Null")) +class AsyncHttpClientClientGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, Target], Ca: CollectionsAbstraction[JavaLanguage]) + extends ClientTerms[JavaLanguage, Target] { + + override implicit def MonadF: Monad[Target] = Target.targetInstances + + override def fromSwagger(context: Context, frameworkImports: List[JavaLanguage#Import])( + serverUrls: Option[NonEmptyList[URI]], + basePath: Option[String], + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[JavaLanguage]], + securitySchemes: Map[String, SecurityScheme[JavaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[Clients[JavaLanguage]] = { + import Sc._ + import Sw._ + import dev.guardrail._ + + for { + clientImports <- getImports(context.tracing) + clientExtraImports <- getExtraImports(context.tracing) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + clients <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes.sortBy(r => (r.path.unwrapTracker, r.method)) + for { + clientName <- formatTypeName(className.lastOption.getOrElse(""), Some("Client")) + responseClientPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses[JavaLanguage, Target](operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + parameters <- route.getParameters[JavaLanguage, Target](components, protocolElems) + methodName <- formatMethodName(operationId) + clientOp <- generateClientOperation(className, responseClsName, context.tracing, securitySchemes, parameters)(route, methodName, responses) + } yield (responseDefinitions, clientOp) + } + (responseDefinitions, clientOperations) = responseClientPair.unzip + tracingName = Option(className.mkString("-")).filterNot(_.isEmpty) + ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing) + staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing) + client <- buildClient( + clientName, + tracingName, + serverUrls, + basePath, + ctorArgs, + clientOperations.map(_.clientOperation), + clientOperations.flatMap(_.supportDefinitions), + context.tracing + ) + } yield Client[JavaLanguage]( + className, + clientName, + clientImports ++ frameworkImports ++ clientExtraImports, + staticDefns, + client, + responseDefinitions.flatten + ) + } + } yield Clients[JavaLanguage](clients, supportDefinitions) + } private val URI_TYPE = StaticJavaParser.parseClassOrInterfaceType("URI") private val OBJECT_MAPPER_TYPE = StaticJavaParser.parseClassOrInterfaceType("ObjectMapper") @@ -450,15 +539,6 @@ object AsyncHttpClientClientGenerator { (imports, cls) } -} - -@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements", "org.wartremover.warts.Null")) -class AsyncHttpClientClientGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, Target], Ca: CollectionsAbstraction[JavaLanguage]) - extends ClientTerms[JavaLanguage, Target] { - import AsyncHttpClientClientGenerator._ - import Ca._ - - implicit def MonadF: Monad[Target] = Target.targetInstances def generateClientOperation( className: List[String], @@ -471,6 +551,8 @@ class AsyncHttpClientClientGenerator private (implicit Cl: CollectionsLibTerms[J methodName: String, responses: Responses[JavaLanguage] ): Target[RenderedClientOperation[JavaLanguage]] = { + import Ca._ + val RouteMeta(pathStr, httpMethod, operation, securityRequirements) = route for { @@ -1064,6 +1146,8 @@ class AsyncHttpClientClientGenerator private (implicit Cl: CollectionsLibTerms[J com.github.javaparser.ast.body.ClassOrInterfaceDeclaration, com.github.javaparser.ast.body.TypeDeclaration[_ <: com.github.javaparser.ast.body.TypeDeclaration[_]] ]]] = { + import Ca._ + def createSetter(tpe: Type, name: String, initializer: String => Target[Expression]): Target[MethodDeclaration] = for { initializerExpr <- initializer(name) diff --git a/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcClientGenerator.scala b/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcClientGenerator.scala index 94c171d5e0..0fcdf398b3 100644 --- a/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcClientGenerator.scala +++ b/modules/java-spring-mvc/src/main/scala/dev/guardrail/generators/java/springMvc/SpringMvcClientGenerator.scala @@ -1,15 +1,17 @@ package dev.guardrail.generators.java.springMvc +import _root_.io.swagger.v3.oas.models.Components import cats.data.NonEmptyList -import java.net.URI - -import dev.guardrail.Target -import dev.guardrail.generators.LanguageParameters +import dev.guardrail._ +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Clients import dev.guardrail.generators.java.JavaLanguage -import dev.guardrail.terms.Responses +import dev.guardrail.terms._ import dev.guardrail.terms.client.ClientTerms +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.StrictProtocolElems -import dev.guardrail.terms.{ CollectionsLibTerms, RouteMeta, SecurityScheme } + +import java.net.URI object SpringMvcClientGenerator { def apply()(implicit Cl: CollectionsLibTerms[JavaLanguage, Target]): ClientTerms[JavaLanguage, Target] = @@ -17,57 +19,21 @@ object SpringMvcClientGenerator { } class SpringMvcClientGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, Target]) extends ClientTerms[JavaLanguage, Target] { - def MonadF = Target.targetInstances - def generateClientOperation( - className: List[String], - responseClsName: String, - tracing: Boolean, - securitySchemes: Map[String, SecurityScheme[JavaLanguage]], - parameters: LanguageParameters[JavaLanguage] - )( - route: RouteMeta, - methodName: String, - responses: Responses[JavaLanguage] - ) = - Target.raiseUserError("spring client generation is not currently supported") - def getImports(tracing: Boolean) = - Target.raiseUserError("spring client generation is not currently supported") - def getExtraImports(tracing: Boolean) = - Target.raiseUserError("spring client generation is not currently supported") - def clientClsArgs( - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - tracing: Boolean - ) = - Target.raiseUserError("spring client generation is not currently supported") - def generateResponseDefinitions( - responseClsName: String, - responses: Responses[JavaLanguage], - protocolElems: List[StrictProtocolElems[JavaLanguage]] - ) = - Target.raiseUserError("spring client generation is not currently supported") - def generateSupportDefinitions( - tracing: Boolean, - securitySchemes: Map[String, SecurityScheme[JavaLanguage]] - ) = - Target.raiseUserError("spring client generation is not currently supported") - def buildStaticDefns( - clientName: String, - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - ctorArgs: List[List[com.github.javaparser.ast.body.Parameter]], - tracing: Boolean - ) = - Target.raiseUserError("spring client generation is not currently supported") - def buildClient( - clientName: String, - tracingName: Option[String], + override def MonadF = Target.targetInstances + + override def fromSwagger(context: Context, frameworkImports: List[JavaLanguage#Import])( serverUrls: Option[NonEmptyList[URI]], basePath: Option[String], - ctorArgs: List[List[com.github.javaparser.ast.body.Parameter]], - clientCalls: List[com.github.javaparser.ast.body.BodyDeclaration[_ <: com.github.javaparser.ast.body.BodyDeclaration[_]]], - supportDefinitions: List[com.github.javaparser.ast.body.BodyDeclaration[_ <: com.github.javaparser.ast.body.BodyDeclaration[_]]], - tracing: Boolean - ) = + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[JavaLanguage]], + securitySchemes: Map[String, SecurityScheme[JavaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[Clients[JavaLanguage]] = Target.raiseUserError("spring client generation is not currently supported") } diff --git a/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpClientGenerator.scala b/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpClientGenerator.scala index 316592163a..f5d33a8485 100644 --- a/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpClientGenerator.scala +++ b/modules/scala-akka-http/src/main/scala/dev/guardrail/generators/scala/akkaHttp/AkkaHttpClientGenerator.scala @@ -1,32 +1,40 @@ package dev.guardrail.generators.scala.akkaHttp +import _root_.io.swagger.v3.oas.models.Components import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ -import java.net.URI -import scala.meta._ -import scala.reflect.runtime.universe.typeTag - -import dev.guardrail.Target -import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ LanguageParameter, LanguageParameters, RawParameterName, RenderedClientOperation } -import dev.guardrail.generators.scala.{ - CirceModelGenerator, - CirceRefinedModelGenerator, - JacksonModelGenerator, - ModelGeneratorType, - ResponseADTHelper, - ScalaLanguage -} +import dev.guardrail._ +import dev.guardrail.core.SupportDefinition +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Client +import dev.guardrail.generators.Clients +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.LanguageParameters +import dev.guardrail.generators.RawParameterName +import dev.guardrail.generators.RenderedClientOperation +import dev.guardrail.generators.scala.CirceModelGenerator +import dev.guardrail.generators.scala.CirceRefinedModelGenerator +import dev.guardrail.generators.scala.JacksonModelGenerator +import dev.guardrail.generators.scala.ModelGeneratorType +import dev.guardrail.generators.scala.ResponseADTHelper +import dev.guardrail.generators.scala.ScalaLanguage import dev.guardrail.generators.scala.syntax._ -import dev.guardrail.generators.spi.{ ClientGeneratorLoader, ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.spi.ClientGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.generators.spi.ProtocolGeneratorLoader import dev.guardrail.generators.syntax._ import dev.guardrail.shims._ +import dev.guardrail.terms._ import dev.guardrail.terms.client.ClientTerms -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.terms.{ ApplicationJson, ContentType, Header, MultipartFormData, Responses, TextPlain } -import dev.guardrail.terms.{ RouteMeta, SecurityScheme } +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StaticDefns +import dev.guardrail.terms.protocol.StrictProtocolElems + +import java.net.URI +import scala.meta._ +import scala.reflect.runtime.universe.typeTag class AkkaHttpClientGeneratorLoader extends ClientGeneratorLoader { type L = ScalaLanguage @@ -65,7 +73,69 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e serverUrls .fold(param"host: String")(v => param"host: String = ${Lit.String(v.head.toString())}") - override def generateClientOperation( + override def fromSwagger(context: Context, frameworkImports: List[ScalaLanguage#Import])( + serverUrls: Option[NonEmptyList[URI]], + basePath: Option[String], + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], + securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Clients[ScalaLanguage]] = { + import Sc._ + import Sw._ + + for { + clientImports <- getImports(context.tracing) + clientExtraImports <- getExtraImports(context.tracing) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + clients <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes.sortBy(r => (r.path.unwrapTracker, r.method)) + for { + clientName <- formatTypeName(className.lastOption.getOrElse(""), Some("Client")) + responseClientPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses[ScalaLanguage, Target](operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + parameters <- route.getParameters[ScalaLanguage, Target](components, protocolElems) + methodName <- formatMethodName(operationId) + clientOp <- generateClientOperation(className, responseClsName, context.tracing, securitySchemes, parameters)(route, methodName, responses) + } yield (responseDefinitions, clientOp) + } + (responseDefinitions, clientOperations) = responseClientPair.unzip + tracingName = Option(className.mkString("-")).filterNot(_.isEmpty) + ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing) + staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing) + client <- buildClient( + clientName, + tracingName, + serverUrls, + basePath, + ctorArgs, + clientOperations.map(_.clientOperation), + clientOperations.flatMap(_.supportDefinitions), + context.tracing + ) + } yield Client[ScalaLanguage]( + className, + clientName, + clientImports ++ frameworkImports ++ clientExtraImports, + staticDefns, + client, + responseDefinitions.flatten + ) + } + } yield Clients[ScalaLanguage](clients, supportDefinitions) + } + + private def generateClientOperation( className: List[String], responseClsName: String, tracing: Boolean, @@ -420,9 +490,9 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e ) } yield renderedClientOperation) } - override def getImports(tracing: Boolean): Target[List[scala.meta.Import]] = Target.pure(List.empty) - override def getExtraImports(tracing: Boolean): Target[List[scala.meta.Import]] = Target.pure(List.empty) - override def clientClsArgs( + private def getImports(tracing: Boolean): Target[List[scala.meta.Import]] = Target.pure(List.empty) + private def getExtraImports(tracing: Boolean): Target[List[scala.meta.Import]] = Target.pure(List.empty) + private def clientClsArgs( tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], tracing: Boolean @@ -441,18 +511,18 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e implicits ++ protocolImplicits ) } - override def generateResponseDefinitions( + private def generateResponseDefinitions( responseClsName: String, responses: Responses[ScalaLanguage], protocolElems: List[StrictProtocolElems[ScalaLanguage]] ): Target[List[scala.meta.Defn]] = Target.pure(ResponseADTHelper.generateResponseDefinitions(responseClsName, responses, protocolElems)) - override def generateSupportDefinitions( + private def generateSupportDefinitions( tracing: Boolean, securitySchemes: Map[String, SecurityScheme[ScalaLanguage]] ): Target[List[SupportDefinition[ScalaLanguage]]] = Target.pure(List.empty) - override def buildStaticDefns( + private def buildStaticDefns( clientName: String, tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], @@ -508,7 +578,7 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e definitions = decls ) } - override def buildClient( + private def buildClient( clientName: String, tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], diff --git a/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardClientGenerator.scala b/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardClientGenerator.scala index d1f12760e5..86fb06aa3b 100644 --- a/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardClientGenerator.scala +++ b/modules/scala-dropwizard/src/main/scala/dev/guardrail/generators/scala/dropwizard/DropwizardClientGenerator.scala @@ -1,20 +1,23 @@ package dev.guardrail.generators.scala.dropwizard +import _root_.io.swagger.v3.oas.models.Components import cats.Monad import cats.data.NonEmptyList -import java.net.URI -import scala.meta.{ Defn, Import, Term } -import scala.reflect.runtime.universe.typeTag - -import dev.guardrail.core.SupportDefinition -import dev.guardrail.generators.{ LanguageParameters, RenderedClientOperation } +import dev.guardrail.Context +import dev.guardrail.RuntimeFailure +import dev.guardrail.Target +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Clients import dev.guardrail.generators.scala.ScalaLanguage -import dev.guardrail.generators.spi.{ ClientGeneratorLoader, ModuleLoadResult } -import dev.guardrail.terms.Responses +import dev.guardrail.generators.spi.ClientGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult +import dev.guardrail.terms._ import dev.guardrail.terms.client.ClientTerms -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.terms.{ RouteMeta, SecurityScheme } -import dev.guardrail.{ RuntimeFailure, Target } +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StrictProtocolElems + +import java.net.URI +import scala.reflect.runtime.universe.typeTag class DropwizardClientGeneratorLoader extends ClientGeneratorLoader { type L = ScalaLanguage @@ -28,43 +31,21 @@ object DropwizardClientGenerator { } class DropwizardClientGenerator private extends ClientTerms[ScalaLanguage, Target] { - override def MonadF: Monad[Target] = Target.targetInstances - override def getImports(tracing: Boolean): Target[List[Import]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def getExtraImports(tracing: Boolean): Target[List[Import]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def clientClsArgs(tracingName: Option[String], serverUrls: Option[NonEmptyList[URI]], tracing: Boolean): Target[List[List[Term.Param]]] = - Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def generateResponseDefinitions( - responseClsName: String, - responses: Responses[ScalaLanguage], - protocolElems: List[StrictProtocolElems[ScalaLanguage]] - ): Target[List[Defn]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def generateSupportDefinitions( - tracing: Boolean, - securitySchemes: Map[String, SecurityScheme[ScalaLanguage]] - ): Target[List[SupportDefinition[ScalaLanguage]]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def buildStaticDefns( - clientName: String, - tracingName: Option[String], + override def MonadF: Monad[Target] = Target.targetInstances + + override def fromSwagger(context: Context, frameworkImports: List[ScalaLanguage#Import])( serverUrls: Option[NonEmptyList[URI]], - ctorArgs: List[List[Term.Param]], - tracing: Boolean - ): Target[StaticDefns[ScalaLanguage]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def generateClientOperation( - className: List[String], - responseClsName: String, - tracing: Boolean, + basePath: Option[String], + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], - parameters: LanguageParameters[ScalaLanguage] - )(route: RouteMeta, methodName: String, responses: Responses[ScalaLanguage]): Target[RenderedClientOperation[ScalaLanguage]] = + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Clients[ScalaLanguage]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) - override def buildClient( - clientName: String, - tracingName: Option[String], - serverUrls: Option[NonEmptyList[URI]], - basePath: Option[String], - ctorArgs: List[List[Term.Param]], - clientCalls: List[Defn], - supportDefinitions: List[Defn], - tracing: Boolean - ): Target[NonEmptyList[Either[Defn.Trait, Defn.Class]]] = Target.raiseError(RuntimeFailure("Dropwizard Scala clients are not yet supported")) } diff --git a/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sClientGenerator.scala b/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sClientGenerator.scala index 3d9d49ff6f..ab513221d4 100644 --- a/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sClientGenerator.scala +++ b/modules/scala-http4s/src/main/scala/dev/guardrail/generators/scala/http4s/Http4sClientGenerator.scala @@ -1,25 +1,36 @@ package dev.guardrail.generators.scala.http4s +import _root_.io.swagger.v3.oas.models.Components import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ -import java.net.URI -import scala.meta._ -import scala.reflect.runtime.universe.typeTag -import scala.collection.immutable.Seq - -import dev.guardrail.Target -import dev.guardrail.core.{ SupportDefinition, Tracker } +import dev.guardrail._ +import dev.guardrail.core.SupportDefinition +import dev.guardrail.core.Tracker +import dev.guardrail.generators.Client +import dev.guardrail.generators.Clients +import dev.guardrail.generators.LanguageParameter +import dev.guardrail.generators.LanguageParameters +import dev.guardrail.generators.RawParameterName +import dev.guardrail.generators.RenderedClientOperation +import dev.guardrail.generators.scala.ResponseADTHelper +import dev.guardrail.generators.scala.ScalaLanguage import dev.guardrail.generators.scala.syntax._ -import dev.guardrail.generators.scala.{ ResponseADTHelper, ScalaLanguage } -import dev.guardrail.generators.spi.{ ClientGeneratorLoader, ModuleLoadResult } +import dev.guardrail.generators.spi.ClientGeneratorLoader +import dev.guardrail.generators.spi.ModuleLoadResult import dev.guardrail.generators.syntax._ -import dev.guardrail.generators.{ LanguageParameter, LanguageParameters, RawParameterName, RenderedClientOperation } import dev.guardrail.shims._ +import dev.guardrail.terms._ import dev.guardrail.terms.client.ClientTerms -import dev.guardrail.terms.protocol.{ StaticDefns, StrictProtocolElems } -import dev.guardrail.terms.{ ApplicationJson, ContentType, Header, MultipartFormData, OctetStream, Responses, RouteMeta, SecurityScheme, TextPlain } +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.StaticDefns +import dev.guardrail.terms.protocol.StrictProtocolElems + +import java.net.URI +import scala.collection.immutable.Seq +import scala.meta._ +import scala.reflect.runtime.universe.typeTag class Http4sClientGeneratorLoader extends ClientGeneratorLoader { type L = ScalaLanguage @@ -40,7 +51,7 @@ class Http4sClientGenerator(version: Http4sVersion) extends ClientTerms[ScalaLan @deprecated("Please specify which http4s version to use", "0.72.0") def this() = this(Http4sVersion.V0_23) - implicit def MonadF: Monad[Target] = Target.targetInstances + override implicit def MonadF: Monad[Target] = Target.targetInstances def splitOperationParts(operationId: String): (List[String], String) = { val parts = operationId.split('.') @@ -56,6 +67,68 @@ class Http4sClientGenerator(version: Http4sVersion) extends ClientTerms[ScalaLan serverUrls .fold(param"host: String")(v => param"host: String = ${Lit.String(v.head.toString())}") + override def fromSwagger(context: Context, frameworkImports: List[ScalaLanguage#Import])( + serverUrls: Option[NonEmptyList[URI]], + basePath: Option[String], + groupedRoutes: List[(List[String], List[RouteMeta])] + )( + protocolElems: List[StrictProtocolElems[ScalaLanguage]], + securitySchemes: Map[String, SecurityScheme[ScalaLanguage]], + components: Tracker[Option[Components]] + )(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Clients[ScalaLanguage]] = { + import Sc._ + import Sw._ + + for { + clientImports <- getImports(context.tracing) + clientExtraImports <- getExtraImports(context.tracing) + supportDefinitions <- generateSupportDefinitions(context.tracing, securitySchemes) + clients <- groupedRoutes.traverse { case (className, unsortedRoutes) => + val routes = unsortedRoutes.sortBy(r => (r.path.unwrapTracker, r.method)) + for { + clientName <- formatTypeName(className.lastOption.getOrElse(""), Some("Client")) + responseClientPair <- routes.traverse { case route @ RouteMeta(path, method, operation, securityRequirements) => + for { + operationId <- getOperationId(operation) + responses <- Responses.getResponses[ScalaLanguage, Target](operationId, operation, protocolElems, components) + responseClsName <- formatTypeName(operationId, Some("Response")) + responseDefinitions <- generateResponseDefinitions(responseClsName, responses, protocolElems) + parameters <- route.getParameters[ScalaLanguage, Target](components, protocolElems) + methodName <- formatMethodName(operationId) + clientOp <- generateClientOperation(className, responseClsName, context.tracing, securitySchemes, parameters)(route, methodName, responses) + } yield (responseDefinitions, clientOp) + } + (responseDefinitions, clientOperations) = responseClientPair.unzip + tracingName = Option(className.mkString("-")).filterNot(_.isEmpty) + ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing) + staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing) + client <- buildClient( + clientName, + tracingName, + serverUrls, + basePath, + ctorArgs, + clientOperations.map(_.clientOperation), + clientOperations.flatMap(_.supportDefinitions), + context.tracing + ) + } yield Client[ScalaLanguage]( + className, + clientName, + clientImports ++ frameworkImports ++ clientExtraImports, + staticDefns, + client, + responseDefinitions.flatten + ) + } + } yield Clients[ScalaLanguage](clients, supportDefinitions) + } + def generateClientOperation( className: List[String], responseClsName: String, From e69b7c9bafefcf84809fc2edbc201192e56b7fab Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 4 Oct 2023 11:15:22 -0700 Subject: [PATCH 6/6] Degeneralize protocol generator --- .../src/main/scala/dev/guardrail/Common.scala | 13 +- .../generators/ProtocolGenerator.scala | 832 +------- .../generators/protocol/ClassHierarchy.scala | 21 + .../dev/guardrail/terms/ProtocolTerms.scala | 254 +-- .../scala/dev/guardrail/terms/package.scala | 3 - .../java/jackson/JacksonGenerator.scala | 882 +++++++- .../scala/circe/CirceProtocolGenerator.scala | 867 +++++++- .../circe/CirceRefinedProtocolGenerator.scala | 1511 ++++++++++++- .../jackson/JacksonProtocolGenerator.scala | 1878 +++++++++++++---- .../tests/circe/ArrayValidationTest.scala | 9 +- .../scala/tests/circe/BigObjectSpec.scala | 5 +- .../scala/tests/circe/ValidationTest.scala | 5 +- 12 files changed, 4742 insertions(+), 1538 deletions(-) create mode 100644 modules/core/src/main/scala/dev/guardrail/generators/protocol/ClassHierarchy.scala delete mode 100644 modules/core/src/main/scala/dev/guardrail/terms/package.scala diff --git a/modules/core/src/main/scala/dev/guardrail/Common.scala b/modules/core/src/main/scala/dev/guardrail/Common.scala index 615860d9b9..765d8ca3b4 100644 --- a/modules/core/src/main/scala/dev/guardrail/Common.scala +++ b/modules/core/src/main/scala/dev/guardrail/Common.scala @@ -8,7 +8,8 @@ import java.nio.file.Path import java.net.URI import dev.guardrail.core.{ SupportDefinition, Tracker } -import dev.guardrail.generators.{ Clients, ProtocolDefinitions, ProtocolGenerator, Servers } +import dev.guardrail.generators.{ Clients, Servers } +import dev.guardrail.generators.ProtocolDefinitions import dev.guardrail.languages.LA import dev.guardrail.terms.client.ClientTerms import dev.guardrail.terms.framework.FrameworkTerms @@ -35,12 +36,16 @@ object Common { Se: ServerTerms[L, F], Sw: SwaggerTerms[L, F] ): F[(ProtocolDefinitions[L], CodegenDefinitions[L])] = { - import Fw._ + import Fw.{ getFrameworkImports, getFrameworkImplicits } import Sw._ Sw.log.function("prepareDefinitions")(for { - proto @ ProtocolDefinitions(protocolElems, protocolImports, packageObjectImports, packageObjectContents, _) <- ProtocolGenerator - .fromSwagger[L, F](swagger, dtoPackage, supportPackage, context.propertyRequirement) + proto @ ProtocolDefinitions(protocolElems, protocolImports, packageObjectImports, packageObjectContents, _) <- P.fromSwagger( + swagger, + dtoPackage, + supportPackage, + context.propertyRequirement + ) serverUrls = NonEmptyList.fromList( swagger diff --git a/modules/core/src/main/scala/dev/guardrail/generators/ProtocolGenerator.scala b/modules/core/src/main/scala/dev/guardrail/generators/ProtocolGenerator.scala index 0839bce59e..cde073c1c8 100644 --- a/modules/core/src/main/scala/dev/guardrail/generators/ProtocolGenerator.scala +++ b/modules/core/src/main/scala/dev/guardrail/generators/ProtocolGenerator.scala @@ -1,36 +1,9 @@ package dev.guardrail.generators -import _root_.io.swagger.v3.oas.models._ -import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, _ } -import cats.Monad -import cats.data.NonEmptyList -import cats.syntax.all._ -import dev.guardrail._ -import dev.guardrail.core.{ DataRedacted, DataVisible, EmptyIsEmpty, EmptyIsNull, Mappish, Tracker } -import dev.guardrail.core.implicits._ +import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, Schema } import dev.guardrail.languages.LA -import dev.guardrail.terms.protocol._ -import dev.guardrail.terms.framework.FrameworkTerms -import dev.guardrail.terms.{ - CollectionsLibTerms, - EnumSchema, - HeldEnum, - IntHeldEnum, - LanguageTerms, - LongHeldEnum, - NumberEnumSchema, - ObjectEnumSchema, - ProtocolTerms, - RenderedIntEnum, - RenderedLongEnum, - RenderedStringEnum, - StringEnumSchema, - StringHeldEnum, - SwaggerTerms -} -import cats.Foldable -import dev.guardrail.core.extract.Default -import scala.jdk.CollectionConverters._ +import dev.guardrail.terms.protocol.StrictProtocolElems +import dev.guardrail.terms.{ EnumSchema, NumberEnumSchema, ObjectEnumSchema, StringEnumSchema } case class ProtocolDefinitions[L <: LA]( elems: List[StrictProtocolElems[L]], @@ -41,807 +14,8 @@ case class ProtocolDefinitions[L <: LA]( ) object ProtocolGenerator { - private[this] def getRequiredFieldsRec(root: Tracker[Schema[_]]): List[String] = { - @scala.annotation.tailrec - def work(values: List[Tracker[Schema[_]]], acc: List[String]): List[String] = { - val required: List[String] = values.flatMap(_.downField("required", _.getRequired()).unwrapTracker) - val next: List[Tracker[Schema[_]]] = - for { - a <- values - b <- a.refine { case x: ComposedSchema => x }(_.downField("allOf", _.getAllOf())).toOption.toList - c <- b.indexedDistribute - } yield c - - val newRequired = acc ++ required - - next match { - case next @ (_ :: _) => work(next, newRequired) - case Nil => newRequired - } - } - work(List(root), Nil) - } - type WrapEnumSchema[A] = Schema[A] => EnumSchema implicit val wrapNumberEnumSchema: WrapEnumSchema[Number] = NumberEnumSchema.apply _ implicit val wrapObjectEnumSchema: WrapEnumSchema[Object] = ObjectEnumSchema.apply _ implicit val wrapStringEnumSchema: WrapEnumSchema[String] = StringEnumSchema.apply _ - - private[this] def fromEnum[L <: LA, F[_], A]( - clsName: String, - schema: Tracker[Schema[A]], - dtoPackage: List[String], - components: Tracker[Option[Components]] - )(implicit - P: ProtocolTerms[L, F], - F: FrameworkTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F], - wrapEnumSchema: WrapEnumSchema[A] - ): F[Either[String, EnumDefinition[L]]] = { - import P._ - import Sc._ - import Sw._ - - def validProg(held: HeldEnum, tpe: L#Type, fullType: L#Type): F[EnumDefinition[L]] = - for { - (pascalValues, wrappedValues) <- held match { - case StringHeldEnum(value) => - for { - elems <- value.traverse { elem => - for { - termName <- formatEnumName(elem) - valueTerm <- pureTermName(termName) - accessor <- buildAccessor(clsName, termName) - } yield (elem, valueTerm, accessor) - } - pascalValues = elems.map(_._2) - wrappedValues = RenderedStringEnum(elems) - } yield (pascalValues, wrappedValues) - case IntHeldEnum(value) => - for { - elems <- value.traverse { elem => - for { - termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms - valueTerm <- pureTermName(termName) - accessor <- buildAccessor(clsName, termName) - } yield (elem, valueTerm, accessor) - } - pascalValues = elems.map(_._2) - wrappedValues = RenderedIntEnum(elems) - } yield (pascalValues, wrappedValues) - case LongHeldEnum(value) => - for { - elems <- value.traverse { elem => - for { - termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms - valueTerm <- pureTermName(termName) - accessor <- buildAccessor(clsName, termName) - } yield (elem, valueTerm, accessor) - } - pascalValues = elems.map(_._2) - wrappedValues = RenderedLongEnum(elems) - } yield (pascalValues, wrappedValues) - } - members <- renderMembers(clsName, wrappedValues) - encoder <- encodeEnum(clsName, tpe) - decoder <- decodeEnum(clsName, tpe) - - defn <- renderClass(clsName, tpe, wrappedValues) - staticDefns <- renderStaticDefns(clsName, tpe, members, pascalValues, encoder, decoder) - classType <- pureTypeName(clsName) - } yield EnumDefinition[L](clsName, classType, fullType, wrappedValues, defn, staticDefns) - - for { - enum <- extractEnum(schema.map(wrapEnumSchema)) - customTpeName <- SwaggerUtil.customTypeName(schema) - (tpe, _) <- SwaggerUtil.determineTypeName(schema, Tracker.cloneHistory(schema, customTpeName), components) - fullType <- selectType(NonEmptyList.ofInitLast(dtoPackage, clsName)) - res <- enum.traverse(validProg(_, tpe, fullType)) - } yield res - } - - private[this] def getPropertyRequirement( - schema: Tracker[Schema[_]], - isRequired: Boolean, - defaultPropertyRequirement: PropertyRequirement - ): PropertyRequirement = - (for { - isNullable <- schema.downField("nullable", _.getNullable) - } yield (isRequired, isNullable) match { - case (true, None) => PropertyRequirement.Required - case (true, Some(false)) => PropertyRequirement.Required - case (true, Some(true)) => PropertyRequirement.RequiredNullable - case (false, None) => defaultPropertyRequirement - case (false, Some(false)) => PropertyRequirement.Optional - case (false, Some(true)) => PropertyRequirement.OptionalNullable - }).unwrapTracker - - /** Handle polymorphic model - */ - private[this] def fromPoly[L <: LA, F[_]]( - hierarchy: ClassParent[L], - concreteTypes: List[PropMeta[L]], - definitions: List[(String, Tracker[Schema[_]])], - dtoPackage: List[String], - supportPackage: List[String], - defaultPropertyRequirement: PropertyRequirement, - components: Tracker[Option[Components]] - )(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[ProtocolElems[L]] = { - import P._ - import Sc._ - - def child(hierarchy: ClassHierarchy[L]): List[String] = - hierarchy.children.map(_.name) ::: hierarchy.children.flatMap(child) - def parent(hierarchy: ClassHierarchy[L]): List[String] = - if (hierarchy.children.nonEmpty) hierarchy.name :: hierarchy.children.flatMap(parent) - else Nil - - val children = child(hierarchy).diff(parent(hierarchy)).distinct - val discriminator = hierarchy.discriminator - - for { - parents <- hierarchy.model - .refine[F[List[SuperClass[L]]]] { case c: ComposedSchema => c }( - extractParents(_, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) - ) - .getOrElse(List.empty[SuperClass[L]].pure[F]) - props <- extractProperties(hierarchy.model) - requiredFields = hierarchy.required ::: hierarchy.children.flatMap(_.required) - params <- props.traverse { case (name, prop) => - for { - typeName <- formatTypeName(name).map(formattedName => NonEmptyList.of(hierarchy.name, formattedName)) - propertyRequirement = getPropertyRequirement(prop, requiredFields.contains(name), defaultPropertyRequirement) - customType <- SwaggerUtil.customTypeName(prop) - resolvedType <- SwaggerUtil - .propMeta[L, F]( - prop, - components - ) // TODO: This should be resolved via an alternate mechanism that maintains references all the way through, instead of re-deriving and assuming that references are valid - defValue <- defaultValue(typeName, prop, propertyRequirement, definitions) - fieldName <- formatFieldName(name) - res <- transformProperty(hierarchy.name, dtoPackage, supportPackage, concreteTypes)( - name, - fieldName, - prop, - resolvedType, - propertyRequirement, - customType.isDefined, - defValue - ) - } yield res - } - definition <- renderSealedTrait(hierarchy.name, params, discriminator, parents, children) - encoder <- encodeADT(hierarchy.name, hierarchy.discriminator, children) - decoder <- decodeADT(hierarchy.name, hierarchy.discriminator, children) - staticDefns <- renderADTStaticDefns(hierarchy.name, discriminator, encoder, decoder) - tpe <- pureTypeName(hierarchy.name) - fullType <- selectType(NonEmptyList.fromList(dtoPackage :+ hierarchy.name).getOrElse(NonEmptyList.of(hierarchy.name))) - } yield ADT[L]( - name = hierarchy.name, - tpe = tpe, - fullType = fullType, - trt = definition, - staticDefns = staticDefns - ) - } - - private def extractParents[L <: LA, F[_]]( - elem: Tracker[ComposedSchema], - definitions: List[(String, Tracker[Schema[_]])], - concreteTypes: List[PropMeta[L]], - dtoPackage: List[String], - supportPackage: List[String], - defaultPropertyRequirement: PropertyRequirement, - components: Tracker[Option[Components]] - )(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[List[SuperClass[L]]] = { - import P._ - import Sc._ - - for { - a <- extractSuperClass(elem, definitions) - supper <- a.flatTraverse { case (clsName, _extends, interfaces) => - val concreteInterfacesWithClass = for { - interface <- interfaces - (cls, tracker) <- definitions - result <- tracker - .refine[Tracker[Schema[_]]] { - case x: ComposedSchema if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x - }( - identity _ - ) - .orRefine { case x: Schema[_] if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x }(identity _) - .toOption - } yield cls -> result - val (_, concreteInterfaces) = concreteInterfacesWithClass.unzip - val classMapping = (for { - (cls, schema) <- concreteInterfacesWithClass - (name, _) <- schema.downField("properties", _.getProperties).indexedDistribute.value - } yield (name, cls)).toMap - for { - _extendsProps <- extractProperties(_extends) - requiredFields = getRequiredFieldsRec(_extends) ++ concreteInterfaces.flatMap(getRequiredFieldsRec) - _withProps <- concreteInterfaces.traverse(extractProperties) - props = _extendsProps ++ _withProps.flatten - (params, _) <- prepareProperties( - NonEmptyList.of(clsName), - classMapping, - props, - requiredFields, - concreteTypes, - definitions, - dtoPackage, - supportPackage, - defaultPropertyRequirement, - components - ) - interfacesCls = interfaces.flatMap(_.downField("$ref", _.get$ref).unwrapTracker.map(_.split("/").last)) - tpe <- parseTypeName(clsName) - - discriminators <- (_extends :: concreteInterfaces).flatTraverse( - _.refine[F[List[Discriminator[L]]]] { case m: ObjectSchema => m }(m => Discriminator.fromSchema(m).map(_.toList)) - .getOrElse(List.empty[Discriminator[L]].pure[F]) - ) - } yield tpe - .map( - SuperClass[L]( - clsName, - _, - interfacesCls, - params, - discriminators - ) - ) - .toList - } - - } yield supper - } - - private[this] def fromModel[L <: LA, F[_]]( - clsName: NonEmptyList[String], - model: Tracker[Schema[_]], - parents: List[SuperClass[L]], - concreteTypes: List[PropMeta[L]], - definitions: List[(String, Tracker[Schema[_]])], - dtoPackage: List[String], - supportPackage: List[String], - defaultPropertyRequirement: PropertyRequirement, - components: Tracker[Option[Components]] - )(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[Either[String, ClassDefinition[L]]] = { - import P._ - import Sc._ - - for { - props <- extractProperties(model) - requiredFields = getRequiredFieldsRec(model) - (params, nestedDefinitions) <- prepareProperties( - clsName, - Map.empty, - props, - requiredFields, - concreteTypes, - definitions, - dtoPackage, - supportPackage, - defaultPropertyRequirement, - components - ) - encoder <- encodeModel(clsName.last, dtoPackage, params, parents) - decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents) - tpe <- parseTypeName(clsName.last) - fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x))) - staticDefns <- renderDTOStaticDefns(clsName.last, List.empty, encoder, decoder, params) - nestedClasses <- nestedDefinitions.flatTraverse { - case classDefinition: ClassDefinition[L] => - for { - widenClass <- widenClassDefinition(classDefinition.cls) - companionTerm <- pureTermName(classDefinition.name) - companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions) - widenCompanion <- companionDefinition.traverse(widenObjectDefinition) - } yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_)) - case enumDefinition: EnumDefinition[L] => - for { - widenClass <- widenClassDefinition(enumDefinition.cls) - companionTerm <- pureTermName(enumDefinition.name) - companionDefinition <- wrapToObject(companionTerm, enumDefinition.staticDefns.extraImports, enumDefinition.staticDefns.definitions) - widenCompanion <- companionDefinition.traverse(widenObjectDefinition) - } yield List(widenClass) ++ widenCompanion.fold(enumDefinition.staticDefns.definitions)(List(_)) - } - defn <- renderDTOClass(clsName.last, supportPackage, params, parents) - } yield { - val finalStaticDefns = staticDefns.copy(definitions = staticDefns.definitions ++ nestedClasses) - if (parents.isEmpty && props.isEmpty) Left("Entity isn't model"): Either[String, ClassDefinition[L]] - else tpe.toRight("Empty entity name").map(ClassDefinition[L](clsName.last, _, fullType, defn, finalStaticDefns, parents)) - } - - } - - private def prepareProperties[L <: LA, F[_]]( - clsName: NonEmptyList[String], - propertyToTypeLookup: Map[String, String], - props: List[(String, Tracker[Schema[_]])], - requiredFields: List[String], - concreteTypes: List[PropMeta[L]], - definitions: List[(String, Tracker[Schema[_]])], - dtoPackage: List[String], - supportPackage: List[String], - defaultPropertyRequirement: PropertyRequirement, - components: Tracker[Option[Components]] - )(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[(List[ProtocolParameter[L]], List[NestedProtocolElems[L]])] = { - import P._ - import Sc._ - def getClsName(name: String): NonEmptyList[String] = propertyToTypeLookup.get(name).map(NonEmptyList.of(_)).getOrElse(clsName) - - def processProperty(name: String, schema: Tracker[Schema[_]]): F[Option[Either[String, NestedProtocolElems[L]]]] = - for { - nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) - defn <- schema - .refine[F[Option[Either[String, NestedProtocolElems[L]]]]] { case x: ObjectSchema => x }(o => - for { - defn <- fromModel( - nestedClassName, - o, - List.empty, - concreteTypes, - definitions, - dtoPackage, - supportPackage, - defaultPropertyRequirement, - components - ) - } yield Option(defn) - ) - .orRefine { case o: ComposedSchema => o }(o => - for { - parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) - maybeClassDefinition <- fromModel( - nestedClassName, - o, - parents, - concreteTypes, - definitions, - dtoPackage, - supportPackage, - defaultPropertyRequirement, - components - ) - } yield Option(maybeClassDefinition) - ) - .orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _))) - .orRefine { case s: StringSchema if Option(s.getEnum).map(_.asScala).exists(_.nonEmpty) => s }(s => - fromEnum(nestedClassName.last, s, dtoPackage, components).map(Option(_)) - ) - .getOrElse(Option.empty[Either[String, NestedProtocolElems[L]]].pure[F]) - } yield defn - - for { - paramsAndNestedDefinitions <- props.traverse[F, (Tracker[ProtocolParameter[L]], Option[NestedProtocolElems[L]])] { case (name, schema) => - for { - typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) - tpe <- selectType(typeName) - maybeNestedDefinition <- processProperty(name, schema) - resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, components) - customType <- SwaggerUtil.customTypeName(schema) - propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement) - defValue <- defaultValue(typeName, schema, propertyRequirement, definitions) - fieldName <- formatFieldName(name) - parameter <- transformProperty(getClsName(name).last, dtoPackage, supportPackage, concreteTypes)( - name, - fieldName, - schema, - resolvedType, - propertyRequirement, - customType.isDefined, - defValue - ) - } yield (Tracker.cloneHistory(schema, parameter), maybeNestedDefinition.flatMap(_.toOption)) - } - (params, nestedDefinitions) = paramsAndNestedDefinitions.unzip - deduplicatedParams <- deduplicateParams(params) - unconflictedParams <- fixConflictingNames(deduplicatedParams) - } yield (unconflictedParams, nestedDefinitions.flatten) - } - - private def deduplicateParams[L <: LA, F[_]]( - params: List[Tracker[ProtocolParameter[L]]] - )(implicit Sw: SwaggerTerms[L, F], Sc: LanguageTerms[L, F]): F[List[ProtocolParameter[L]]] = { - import Sc._ - Foldable[List] - .foldLeftM[F, Tracker[ProtocolParameter[L]], List[ProtocolParameter[L]]](params, List.empty[ProtocolParameter[L]]) { (s, ta) => - val a = ta.unwrapTracker - s.find(p => p.name == a.name) match { - case None => (a :: s).pure[F] - case Some(duplicate) => - for { - newDefaultValue <- findCommonDefaultValue(ta.showHistory, a.defaultValue, duplicate.defaultValue) - newRawType <- findCommonRawType(ta.showHistory, a.rawType, duplicate.rawType) - } yield { - val emptyToNull = if (Set(a.emptyToNull, duplicate.emptyToNull).contains(EmptyIsNull)) EmptyIsNull else EmptyIsEmpty - val redactionBehaviour = if (Set(a.dataRedaction, duplicate.dataRedaction).contains(DataRedacted)) DataRedacted else DataVisible - val mergedParameter = ProtocolParameter[L]( - a.term, - a.baseType, - a.name, - a.dep, - newRawType, - a.readOnlyKey.orElse(duplicate.readOnlyKey), - emptyToNull, - redactionBehaviour, - a.propertyRequirement, - newDefaultValue, - a.propertyValidation - ) - mergedParameter :: s.filter(_.name != a.name) - } - } - } - .map(_.reverse) - } - - private def fixConflictingNames[L <: LA, F[_]](params: List[ProtocolParameter[L]])(implicit Lt: LanguageTerms[L, F]): F[List[ProtocolParameter[L]]] = { - import Lt._ - for { - paramsWithNames <- params.traverse(param => extractTermNameFromParam(param.term).map((_, param))) - counts = paramsWithNames.groupBy(_._1).view.mapValues(_.length).toMap - newParams <- paramsWithNames.traverse { case (name, param) => - if (counts.getOrElse(name, 0) > 1) { - for { - newTermName <- pureTermName(param.name.value) - newMethodParam <- alterMethodParameterName(param.term, newTermName) - } yield ProtocolParameter( - newMethodParam, - param.baseType, - param.name, - param.dep, - param.rawType, - param.readOnlyKey, - param.emptyToNull, - param.dataRedaction, - param.propertyRequirement, - param.defaultValue, - param.propertyValidation - ) - } else { - param.pure[F] - } - } - } yield newParams - } - - def modelTypeAlias[L <: LA, F[_]](clsName: String, abstractModel: Tracker[Schema[_]], components: Tracker[Option[Components]])(implicit - Fw: FrameworkTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[ProtocolElems[L]] = { - import Fw._ - val model: Option[Tracker[ObjectSchema]] = abstractModel - .refine[Option[Tracker[ObjectSchema]]] { case m: ObjectSchema => m }(x => Option(x)) - .orRefine { case m: ComposedSchema => m }( - _.downField("allOf", _.getAllOf()).indexedCosequence - .get(1) - .flatMap( - _.refine { case o: ObjectSchema => o }(Option.apply) - .orRefineFallback(_ => None) - ) - ) - .orRefineFallback(_ => None) - for { - tpe <- model.fold[F[L#Type]](objectType(None)) { m => - for { - tpeName <- SwaggerUtil.customTypeName[L, F, Tracker[ObjectSchema]](m) - (declType, _) <- SwaggerUtil.determineTypeName[L, F](m, Tracker.cloneHistory(m, tpeName), components) - } yield declType - } - res <- typeAlias[L, F](clsName, tpe) - } yield res - } - - def plainTypeAlias[L <: LA, F[_]]( - clsName: String - )(implicit Fw: FrameworkTerms[L, F], Sc: LanguageTerms[L, F]): F[ProtocolElems[L]] = { - import Fw._ - for { - tpe <- objectType(None) - res <- typeAlias[L, F](clsName, tpe) - } yield res - } - - def typeAlias[L <: LA, F[_]: Monad](clsName: String, tpe: L#Type): F[ProtocolElems[L]] = - (RandomType[L](clsName, tpe): ProtocolElems[L]).pure[F] - - def fromArray[L <: LA, F[_]](clsName: String, arr: Tracker[ArraySchema], concreteTypes: List[PropMeta[L]], components: Tracker[Option[Components]])(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[ProtocolElems[L]] = { - import P._ - for { - deferredTpe <- SwaggerUtil.modelMetaType(arr, components) - tpe <- extractArrayType(deferredTpe, concreteTypes) - ret <- typeAlias[L, F](clsName, tpe) - } yield ret - } - - sealed trait ClassHierarchy[L <: LA] { - def name: String - def model: Tracker[Schema[_]] - def children: List[ClassChild[L]] - def required: List[String] - } - case class ClassChild[L <: LA](name: String, model: Tracker[Schema[_]], children: List[ClassChild[L]], required: List[String]) extends ClassHierarchy[L] - case class ClassParent[L <: LA]( - name: String, - model: Tracker[Schema[_]], - children: List[ClassChild[L]], - discriminator: Discriminator[L], - required: List[String] - ) extends ClassHierarchy[L] - - /** returns objects grouped into hierarchies - */ - def groupHierarchies[L <: LA, F[_]]( - definitions: Mappish[List, String, Tracker[Schema[_]]] - )(implicit Sc: LanguageTerms[L, F], Sw: SwaggerTerms[L, F]): F[(List[ClassParent[L]], List[(String, Tracker[Schema[_]])])] = { - - def firstInHierarchy(model: Tracker[Schema[_]]): Option[Tracker[ObjectSchema]] = - model - .refine { case x: ComposedSchema => x } { elem => - definitions.value - .collectFirst { - case (clsName, element) - if elem.downField("allOf", _.getAllOf).exists(_.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName"))) => - element - } - .flatMap( - _.refine { case x: ComposedSchema => x }(firstInHierarchy) - .orRefine { case o: ObjectSchema => o }(x => Option(x)) - .getOrElse(None) - ) - } - .getOrElse(None) - - def children(cls: String): List[ClassChild[L]] = definitions.value.flatMap { case (clsName, comp) => - comp - .refine { case x: ComposedSchema => x }(comp => - if ( - comp - .downField("allOf", _.getAllOf()) - .exists(x => x.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$cls"))) - ) { - Some(ClassChild(clsName, comp, children(clsName), getRequiredFieldsRec(comp))) - } else None - ) - .getOrElse(None) - } - - def classHierarchy(cls: String, model: Tracker[Schema[_]]): F[Option[ClassParent[L]]] = - model - .refine { case c: ComposedSchema => c }(c => - firstInHierarchy(c) - .fold(Option.empty[Discriminator[L]].pure[F])(Discriminator.fromSchema[L, F]) - .map(_.map((_, getRequiredFieldsRec(c)))) - ) - .orRefine { case x: Schema[_] => x }(m => Discriminator.fromSchema(m).map(_.map((_, getRequiredFieldsRec(m))))) - .getOrElse(Option.empty[(Discriminator[L], List[String])].pure[F]) - .map(_.map { case (discriminator, reqFields) => ClassParent(cls, model, children(cls), discriminator, reqFields) }) - - Sw.log.function("groupHierarchies")( - definitions.value - .traverse { case (cls, model) => - for { - hierarchy <- classHierarchy(cls, model) - } yield hierarchy.filterNot(_.children.isEmpty).toLeft((cls, model)) - } - .map(_.partitionEither[ClassParent[L], (String, Tracker[Schema[_]])](identity)) - ) - } - - def fromSwagger[L <: LA, F[_]]( - swagger: Tracker[OpenAPI], - dtoPackage: List[String], - supportPackage: NonEmptyList[String], - defaultPropertyRequirement: PropertyRequirement - )(implicit - F: FrameworkTerms[L, F], - P: ProtocolTerms[L, F], - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F], - Sw: SwaggerTerms[L, F] - ): F[ProtocolDefinitions[L]] = { - import P._ - import Sc._ - - val components = swagger.downField("components", _.getComponents()) - val definitions = components.flatDownField("schemas", _.getSchemas()).indexedCosequence - Sw.log.function("ProtocolGenerator.fromSwagger")(for { - (hierarchies, definitionsWithoutPoly) <- groupHierarchies(definitions) - - concreteTypes <- SwaggerUtil.extractConcreteTypes[L, F](definitions.value, components) - polyADTs <- hierarchies.traverse(fromPoly(_, concreteTypes, definitions.value, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components)) - elems <- definitionsWithoutPoly.traverse { case (clsName, model) => - model - .refine { case c: ComposedSchema => c }(comp => - for { - formattedClsName <- formatTypeName(clsName) - parents <- extractParents(comp, definitions.value, concreteTypes, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components) - model <- fromModel( - clsName = NonEmptyList.of(formattedClsName), - model = comp, - parents = parents, - concreteTypes = concreteTypes, - definitions = definitions.value, - dtoPackage = dtoPackage, - supportPackage = supportPackage.toList, - defaultPropertyRequirement = defaultPropertyRequirement, - components = components - ) - alias <- modelTypeAlias(formattedClsName, comp, components) - } yield model.getOrElse(alias) - ) - .orRefine { case a: ArraySchema => a }(arr => - for { - formattedClsName <- formatTypeName(clsName) - array <- fromArray(formattedClsName, arr, concreteTypes, components) - } yield array - ) - .orRefine { case o: ObjectSchema => o }(m => - for { - formattedClsName <- formatTypeName(clsName) - enum <- fromEnum(formattedClsName, m, dtoPackage, components) - model <- fromModel( - NonEmptyList.of(formattedClsName), - m, - List.empty, - concreteTypes, - definitions.value, - dtoPackage, - supportPackage.toList, - defaultPropertyRequirement, - components - ) - alias <- modelTypeAlias(formattedClsName, m, components) - } yield enum.orElse(model).getOrElse(alias) - ) - .orRefine { case x: StringSchema => x }(x => - for { - formattedClsName <- formatTypeName(clsName) - enum <- fromEnum(formattedClsName, x, dtoPackage, components) - model <- fromModel( - NonEmptyList.of(formattedClsName), - x, - List.empty, - concreteTypes, - definitions.value, - dtoPackage, - supportPackage.toList, - defaultPropertyRequirement, - components - ) - customTypeName <- SwaggerUtil.customTypeName(x) - (declType, _) <- SwaggerUtil.determineTypeName[L, F](x, Tracker.cloneHistory(x, customTypeName), components) - alias <- typeAlias[L, F](formattedClsName, declType) - } yield enum.orElse(model).getOrElse(alias) - ) - .orRefine { case x: IntegerSchema => x }(x => - for { - formattedClsName <- formatTypeName(clsName) - enum <- fromEnum(formattedClsName, x, dtoPackage, components) - model <- fromModel( - NonEmptyList.of(formattedClsName), - x, - List.empty, - concreteTypes, - definitions.value, - dtoPackage, - supportPackage.toList, - defaultPropertyRequirement, - components - ) - customTypeName <- SwaggerUtil.customTypeName(x) - (declType, _) <- SwaggerUtil.determineTypeName[L, F](x, Tracker.cloneHistory(x, customTypeName), components) - alias <- typeAlias[L, F](formattedClsName, declType) - } yield enum.orElse(model).getOrElse(alias) - ) - .valueOr(x => - for { - formattedClsName <- formatTypeName(clsName) - customTypeName <- SwaggerUtil.customTypeName(x) - (declType, _) <- SwaggerUtil.determineTypeName[L, F](x, Tracker.cloneHistory(x, customTypeName), components) - res <- typeAlias[L, F](formattedClsName, declType) - } yield res - ) - } - protoImports <- protocolImports() - pkgImports <- packageObjectImports() - pkgObjectContents <- packageObjectContents() - implicitsObject <- implicitsObject() - - polyADTElems <- ProtocolElems.resolve[L, F](polyADTs) - strictElems <- ProtocolElems.resolve[L, F](elems) - } yield ProtocolDefinitions[L](strictElems ++ polyADTElems, protoImports, pkgImports, pkgObjectContents, implicitsObject)) - } - - private def defaultValue[L <: LA, F[_]]( - name: NonEmptyList[String], - schema: Tracker[Schema[_]], - requirement: PropertyRequirement, - definitions: List[(String, Tracker[Schema[_]])] - )(implicit - Sc: LanguageTerms[L, F], - Cl: CollectionsLibTerms[L, F] - ): F[Option[L#Term]] = { - import Sc._ - import Cl._ - val empty = Option.empty[L#Term].pure[F] - schema.downField("$ref", _.get$ref()).indexedDistribute match { - case Some(ref) => - definitions - .collectFirst { - case (cls, refSchema) if ref.unwrapTracker.endsWith(s"/$cls") => - defaultValue(NonEmptyList.of(cls), refSchema, requirement, definitions) - } - .getOrElse(empty) - case None => - schema - .refine { case map: MapSchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => map }(map => - for { - customTpe <- SwaggerUtil.customMapTypeName(map) - result <- customTpe.fold(emptyMap().map(Option(_)))(_ => empty) - } yield result - ) - .orRefine { case arr: ArraySchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => arr }( - arr => - for { - customTpe <- SwaggerUtil.customArrayTypeName(arr) - result <- customTpe.fold(emptyArray().map(Option(_)))(_ => empty) - } yield result - ) - .orRefine { case p: BooleanSchema => p }(p => Default(p).extract[Boolean].fold(empty)(litBoolean(_).map(Some(_)))) - .orRefine { case p: NumberSchema if p.getFormat == "double" => p }(p => Default(p).extract[Double].fold(empty)(litDouble(_).map(Some(_)))) - .orRefine { case p: NumberSchema if p.getFormat == "float" => p }(p => Default(p).extract[Float].fold(empty)(litFloat(_).map(Some(_)))) - .orRefine { case p: IntegerSchema if p.getFormat == "int32" => p }(p => Default(p).extract[Int].fold(empty)(litInt(_).map(Some(_)))) - .orRefine { case p: IntegerSchema if p.getFormat == "int64" => p }(p => Default(p).extract[Long].fold(empty)(litLong(_).map(Some(_)))) - .orRefine { case p: StringSchema if Option(p.getEnum).map(_.asScala).exists(_.nonEmpty) => p }(p => - Default(p).extract[String] match { - case Some(defaultEnumValue) => - for { - enumName <- formatEnumName(defaultEnumValue) - result <- selectTerm(name.append(enumName)) - } yield Some(result) - case None => empty - } - ) - .orRefine { case p: StringSchema => p }(p => Default(p).extract[String].fold(empty)(litString(_).map(Some(_)))) - .getOrElse(empty) - } - } } diff --git a/modules/core/src/main/scala/dev/guardrail/generators/protocol/ClassHierarchy.scala b/modules/core/src/main/scala/dev/guardrail/generators/protocol/ClassHierarchy.scala new file mode 100644 index 0000000000..66e7331cd3 --- /dev/null +++ b/modules/core/src/main/scala/dev/guardrail/generators/protocol/ClassHierarchy.scala @@ -0,0 +1,21 @@ +package dev.guardrail.generators.protocol + +import _root_.io.swagger.v3.oas.models.media.Schema +import dev.guardrail.languages.LA +import dev.guardrail.core.Tracker +import dev.guardrail.terms.protocol.Discriminator + +sealed trait ClassHierarchy[L <: LA] { + def name: String + def model: Tracker[Schema[_]] + def children: List[ClassChild[L]] + def required: List[String] +} +case class ClassChild[L <: LA](name: String, model: Tracker[Schema[_]], children: List[ClassChild[L]], required: List[String]) extends ClassHierarchy[L] +case class ClassParent[L <: LA]( + name: String, + model: Tracker[Schema[_]], + children: List[ClassChild[L]], + discriminator: Discriminator[L], + required: List[String] +) extends ClassHierarchy[L] diff --git a/modules/core/src/main/scala/dev/guardrail/terms/ProtocolTerms.scala b/modules/core/src/main/scala/dev/guardrail/terms/ProtocolTerms.scala index 40593a2a36..778f1746c7 100644 --- a/modules/core/src/main/scala/dev/guardrail/terms/ProtocolTerms.scala +++ b/modules/core/src/main/scala/dev/guardrail/terms/ProtocolTerms.scala @@ -1,254 +1,34 @@ package dev.guardrail.terms import cats.Monad -import io.swagger.v3.oas.models.media.{ ComposedSchema, Schema } -import dev.guardrail.core.{ ResolvedType, SupportDefinition, Tracker } +import cats.data.NonEmptyList +import _root_.io.swagger.v3.oas.models.OpenAPI +import dev.guardrail.core.{ SupportDefinition, Tracker } import dev.guardrail.languages.LA -import dev.guardrail.terms.protocol.{ Discriminator, PropMeta, PropertyRequirement, ProtocolParameter, StaticDefns, SuperClass } +import dev.guardrail.terms.protocol.PropertyRequirement +import dev.guardrail.generators.ProtocolDefinitions +import dev.guardrail.terms.framework.FrameworkTerms import scala.collection.immutable.List abstract class ProtocolTerms[L <: LA, F[_]] { self => def MonadF: Monad[F] - // ProtocolTerms - def extractConcreteTypes(models: Either[String, List[PropMeta[L]]]): F[List[PropMeta[L]]] def staticProtocolImports(pkgName: List[String]): F[List[L#Import]] - def protocolImports(): F[List[L#Import]] - def packageObjectImports(): F[List[L#Import]] - def packageObjectContents(): F[List[L#Statement]] - def implicitsObject(): F[Option[(L#TermName, L#ObjectDefinition)]] def generateSupportDefinitions(): F[List[SupportDefinition[L]]] - // ModelProtocolTerms - def extractProperties(swagger: Tracker[Schema[_]]): F[List[(String, Tracker[Schema[_]])]] - def transformProperty(clsName: String, dtoPackage: List[String], supportPackage: List[String], concreteTypes: List[PropMeta[L]])( - name: String, - fieldName: String, - prop: Tracker[Schema[_]], - meta: ResolvedType[L], - requirement: PropertyRequirement, - isCustomType: Boolean, - defaultValue: Option[L#Term] - ): F[ProtocolParameter[L]] - def renderDTOClass(clsName: String, supportPackage: List[String], terms: List[ProtocolParameter[L]], parents: List[SuperClass[L]] = Nil): F[L#ClassDefinition] - def encodeModel( - clsName: String, + def fromSwagger( + swagger: Tracker[OpenAPI], dtoPackage: List[String], - params: List[ProtocolParameter[L]], - parents: List[SuperClass[L]] = Nil - ): F[Option[L#ValueDefinition]] - def decodeModel( - clsName: String, - dtoPackage: List[String], - supportPackage: List[String], - params: List[ProtocolParameter[L]], - parents: List[SuperClass[L]] = Nil - ): F[Option[L#ValueDefinition]] - def renderDTOStaticDefns( - clsName: String, - deps: List[L#TermName], - encoder: Option[L#ValueDefinition], - decoder: Option[L#ValueDefinition], - protocolParameters: List[ProtocolParameter[L]] - ): F[StaticDefns[L]] - - // ArrayProtocolTerms - def extractArrayType(arr: ResolvedType[L], concreteTypes: List[PropMeta[L]]): F[L#Type] + supportPackage: NonEmptyList[String], + defaultPropertyRequirement: PropertyRequirement + )(implicit + F: FrameworkTerms[L, F], + P: ProtocolTerms[L, F], + Sc: LanguageTerms[L, F], + Cl: CollectionsLibTerms[L, F], + Sw: SwaggerTerms[L, F] + ): F[ProtocolDefinitions[L]] - // EnumProtocolTerms - def renderMembers(clsName: String, elems: RenderedEnum[L]): F[Option[L#ObjectDefinition]] - def encodeEnum(clsName: String, tpe: L#Type): F[Option[L#Definition]] - def decodeEnum(clsName: String, tpe: L#Type): F[Option[L#Definition]] - def renderClass(clsName: String, tpe: L#Type, elems: RenderedEnum[L]): F[L#ClassDefinition] - def renderStaticDefns( - clsName: String, - tpe: L#Type, - members: Option[L#ObjectDefinition], - accessors: List[L#TermName], - encoder: Option[L#Definition], - decoder: Option[L#Definition] - ): F[StaticDefns[L]] def buildAccessor(clsName: String, termName: String): F[L#TermSelect] - - // PolyProtocolTerms - def extractSuperClass( - swagger: Tracker[ComposedSchema], - definitions: List[(String, Tracker[Schema[_]])] - ): F[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]] - def renderSealedTrait( - className: String, - params: List[ProtocolParameter[L]], - discriminator: Discriminator[L], - parents: List[SuperClass[L]] = Nil, - children: List[String] = Nil - ): F[L#Trait] - def encodeADT(clsName: String, discriminator: Discriminator[L], children: List[String] = Nil): F[Option[L#ValueDefinition]] - def decodeADT(clsName: String, discriminator: Discriminator[L], children: List[String] = Nil): F[Option[L#ValueDefinition]] - def renderADTStaticDefns( - clsName: String, - discriminator: Discriminator[L], - encoder: Option[L#ValueDefinition], - decoder: Option[L#ValueDefinition] - ): F[StaticDefns[L]] - - def copy( - MonadF: Monad[F] = self.MonadF, - extractConcreteTypes: Either[String, List[PropMeta[L]]] => F[List[PropMeta[L]]] = self.extractConcreteTypes, - staticProtocolImports: ((List[String]) => F[List[L#Import]]) = self.staticProtocolImports _, - protocolImports: (() => F[List[L#Import]]) = self.protocolImports _, - packageObjectImports: (() => F[List[L#Import]]) = self.packageObjectImports _, - packageObjectContents: (() => F[List[L#Statement]]) = self.packageObjectContents _, - implicitsObject: () => F[Option[(L#TermName, L#ObjectDefinition)]] = self.implicitsObject _, - generateSupportDefinitions: (() => F[List[SupportDefinition[L]]]) = self.generateSupportDefinitions _, - extractProperties: Tracker[Schema[_]] => F[List[(String, Tracker[Schema[_]])]] = self.extractProperties _, - transformProperty: ( - String, - List[String], - List[String], - List[PropMeta[L]] - ) => (String, String, Tracker[Schema[_]], ResolvedType[L], PropertyRequirement, Boolean, Option[L#Term]) => F[ProtocolParameter[L]] = - self.transformProperty _, - renderDTOClass: (String, List[String], List[ProtocolParameter[L]], List[SuperClass[L]]) => F[L#ClassDefinition] = self.renderDTOClass _, - decodeModel: (String, List[String], List[String], List[ProtocolParameter[L]], List[SuperClass[L]]) => F[Option[L#ValueDefinition]] = self.decodeModel _, - encodeModel: (String, List[String], List[ProtocolParameter[L]], List[SuperClass[L]]) => F[Option[L#ValueDefinition]] = self.encodeModel _, - renderDTOStaticDefns: (String, List[L#TermName], Option[L#ValueDefinition], Option[L#ValueDefinition], List[ProtocolParameter[L]]) => F[StaticDefns[L]] = - self.renderDTOStaticDefns _, - extractArrayType: (ResolvedType[L], List[PropMeta[L]]) => F[L#Type] = self.extractArrayType _, - extractSuperClass: (Tracker[ComposedSchema], List[(String, Tracker[Schema[_]])]) => F[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]] = - self.extractSuperClass _, - renderSealedTrait: (String, List[ProtocolParameter[L]], Discriminator[L], List[SuperClass[L]], List[String]) => F[L#Trait] = self.renderSealedTrait _, - encodeADT: (String, Discriminator[L], List[String]) => F[Option[L#ValueDefinition]] = self.encodeADT _, - decodeADT: (String, Discriminator[L], List[String]) => F[Option[L#ValueDefinition]] = self.decodeADT _, - renderADTStaticDefns: (String, Discriminator[L], Option[L#ValueDefinition], Option[L#ValueDefinition]) => F[StaticDefns[L]] = self.renderADTStaticDefns _, - renderMembers: (String, RenderedEnum[L]) => F[Option[L#ObjectDefinition]] = self.renderMembers _, - encodeEnum: (String, L#Type) => F[Option[L#Definition]] = self.encodeEnum _, - decodeEnum: (String, L#Type) => F[Option[L#Definition]] = self.decodeEnum _, - renderClass: (String, L#Type, RenderedEnum[L]) => F[L#ClassDefinition] = self.renderClass _, - renderStaticDefns: (String, L#Type, Option[L#ObjectDefinition], List[L#TermName], Option[L#Definition], Option[L#Definition]) => F[StaticDefns[L]] = - self.renderStaticDefns _, - buildAccessor: (String, String) => F[L#TermSelect] = self.buildAccessor _ - ): ProtocolTerms[L, F] = { - val newMonadF = MonadF - val newExtractConcreteTypes = extractConcreteTypes - val newStaticProtocolImports = staticProtocolImports - val newProtocolImports = protocolImports - val newPackageObjectImports = packageObjectImports - val newPackageObjectContents = packageObjectContents - val newImplicitsObject = implicitsObject - val newGenerateSupportDefinitions = generateSupportDefinitions - - val newExtractProperties = extractProperties - val newTransformProperty = transformProperty - val newRenderDTOClass = renderDTOClass - val newDecodeModel = decodeModel - val newEncodeModel = encodeModel - val newRenderDTOStaticDefns = renderDTOStaticDefns - - val newExtractArrayType = extractArrayType - - val newExtractSuperClass = extractSuperClass - val newRenderSealedTrait = renderSealedTrait - val newEncodeADT = encodeADT - val newDecodeADT = decodeADT - val newRenderADTStaticDefns = renderADTStaticDefns - - val newRenderMembers = renderMembers - val newEncodeEnum = encodeEnum - val newDecodeEnum = decodeEnum - val newRenderClass = renderClass - val newRenderStaticDefns = renderStaticDefns - val newBuildAccessor = buildAccessor - - new ProtocolTerms[L, F] { - def MonadF = newMonadF - def extractConcreteTypes(models: Either[String, List[PropMeta[L]]]) = newExtractConcreteTypes(models) - def staticProtocolImports(pkgName: List[String]) = newStaticProtocolImports(pkgName) - def protocolImports() = newProtocolImports() - def packageObjectImports() = newPackageObjectImports() - def packageObjectContents() = newPackageObjectContents() - def implicitsObject() = newImplicitsObject() - def generateSupportDefinitions() = newGenerateSupportDefinitions() - - def extractProperties(swagger: Tracker[Schema[_]]) = newExtractProperties(swagger) - def transformProperty( - clsName: String, - dtoPackage: List[String], - supportPackage: List[String], - concreteTypes: List[PropMeta[L]] - )( - name: String, - fieldName: String, - prop: Tracker[Schema[_]], - meta: ResolvedType[L], - requirement: PropertyRequirement, - isCustomType: Boolean, - defaultValue: Option[L#Term] - ) = - newTransformProperty(clsName, dtoPackage, supportPackage, concreteTypes)( - name, - fieldName, - prop, - meta, - requirement, - isCustomType, - defaultValue - ) - def renderDTOClass(clsName: String, supportPackage: List[String], terms: List[ProtocolParameter[L]], parents: List[SuperClass[L]] = Nil) = - newRenderDTOClass(clsName, supportPackage, terms, parents) - def encodeModel( - clsName: String, - dtoPackage: List[String], - params: List[ProtocolParameter[L]], - parents: List[SuperClass[L]] = Nil - ) = - newEncodeModel(clsName, dtoPackage, params, parents) - def decodeModel( - clsName: String, - dtoPackage: List[String], - supportPackage: List[String], - params: List[ProtocolParameter[L]], - parents: List[SuperClass[L]] = Nil - ) = - newDecodeModel(clsName, dtoPackage, supportPackage, params, parents) - - def renderDTOStaticDefns( - clsName: String, - deps: List[L#TermName], - encoder: Option[L#ValueDefinition], - decoder: Option[L#ValueDefinition], - params: List[ProtocolParameter[L]] - ) = - newRenderDTOStaticDefns(clsName, deps, encoder, decoder, params) - - def extractArrayType(arr: ResolvedType[L], concreteTypes: List[PropMeta[L]]) = newExtractArrayType(arr, concreteTypes) - - def extractSuperClass(swagger: Tracker[ComposedSchema], definitions: List[(String, Tracker[Schema[_]])]) = newExtractSuperClass(swagger, definitions) - def renderSealedTrait( - className: String, - params: List[ProtocolParameter[L]], - discriminator: Discriminator[L], - parents: List[SuperClass[L]] = Nil, - children: List[String] = Nil - ) = newRenderSealedTrait(className, params, discriminator, parents, children) - def encodeADT(clsName: String, discriminator: Discriminator[L], children: List[String] = Nil) = newEncodeADT(clsName, discriminator, children) - def decodeADT(clsName: String, discriminator: Discriminator[L], children: List[String] = Nil) = newDecodeADT(clsName, discriminator, children) - def renderADTStaticDefns(clsName: String, discriminator: Discriminator[L], encoder: Option[L#ValueDefinition], decoder: Option[L#ValueDefinition]) = - newRenderADTStaticDefns(clsName, discriminator, encoder, decoder) - - def renderMembers(clsName: String, elems: RenderedEnum[L]) = newRenderMembers(clsName, elems) - def encodeEnum(clsName: String, tpe: L#Type): F[Option[L#Definition]] = newEncodeEnum(clsName, tpe) - def decodeEnum(clsName: String, tpe: L#Type): F[Option[L#Definition]] = newDecodeEnum(clsName, tpe) - def renderClass(clsName: String, tpe: L#Type, elems: RenderedEnum[L]) = newRenderClass(clsName, tpe, elems) - def renderStaticDefns( - clsName: String, - tpe: L#Type, - members: Option[L#ObjectDefinition], - accessors: List[L#TermName], - encoder: Option[L#Definition], - decoder: Option[L#Definition] - ): F[StaticDefns[L]] = newRenderStaticDefns(clsName, tpe, members, accessors, encoder, decoder) - def buildAccessor(clsName: String, termName: String) = newBuildAccessor(clsName, termName) - } - } } diff --git a/modules/core/src/main/scala/dev/guardrail/terms/package.scala b/modules/core/src/main/scala/dev/guardrail/terms/package.scala deleted file mode 100644 index fb6392a994..0000000000 --- a/modules/core/src/main/scala/dev/guardrail/terms/package.scala +++ /dev/null @@ -1,3 +0,0 @@ -package dev.guardrail - -package object terms extends CollectionsSyntax diff --git a/modules/java-support/src/main/scala/dev/guardrail/generators/java/jackson/JacksonGenerator.scala b/modules/java-support/src/main/scala/dev/guardrail/generators/java/jackson/JacksonGenerator.scala index 8886124d2d..dd29a710dc 100644 --- a/modules/java-support/src/main/scala/dev/guardrail/generators/java/jackson/JacksonGenerator.scala +++ b/modules/java-support/src/main/scala/dev/guardrail/generators/java/jackson/JacksonGenerator.scala @@ -1,35 +1,68 @@ package dev.guardrail.generators.java.jackson +import _root_.io.swagger.v3.oas.models.{ Components, OpenAPI } import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, _ } -import cats.{ FlatMap, Monad } + +import cats.Foldable import cats.data.NonEmptyList import cats.syntax.all._ +import cats.{ FlatMap, Monad } + import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type, UnknownType } import com.github.javaparser.ast.Modifier.Keyword.{ FINAL, PRIVATE, PROTECTED, PUBLIC } import com.github.javaparser.ast.Modifier._ -import com.github.javaparser.ast.{ Node, NodeList } +import com.github.javaparser.ast.`type`.{ ClassOrInterfaceType, PrimitiveType, Type, UnknownType } import com.github.javaparser.ast.body._ import com.github.javaparser.ast.expr.{ MethodCallExpr, _ } import com.github.javaparser.ast.stmt._ - +import com.github.javaparser.ast.{ Node, NodeList } +import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.typeTag + +import dev.guardrail.SwaggerUtil import dev.guardrail.core -import dev.guardrail.core.{ LiteralRawType, ReifiedRawType, Tracker } -import dev.guardrail.core.extract.{ DataRedaction, EmptyValueIsNull } +import dev.guardrail.core.extract.{ DataRedaction, Default, EmptyValueIsNull } import dev.guardrail.core.implicits._ -import dev.guardrail.core.{ DataRedacted, DataVisible, EmptyIsEmpty, EmptyIsNull, EmptyToNullBehaviour, RedactionBehaviour } +import dev.guardrail.core.{ + DataRedacted, + DataVisible, + EmptyIsEmpty, + EmptyIsNull, + EmptyToNullBehaviour, + LiteralRawType, + Mappish, + RedactionBehaviour, + ReifiedRawType, + Tracker +} +import dev.guardrail.generators.ProtocolDefinitions +import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEnumSchema, wrapObjectEnumSchema, wrapStringEnumSchema } +import dev.guardrail.generators.RawParameterName import dev.guardrail.generators.java.JavaCollectionsGenerator import dev.guardrail.generators.java.JavaGenerator import dev.guardrail.generators.java.JavaLanguage import dev.guardrail.generators.java.JavaVavrCollectionsGenerator import dev.guardrail.generators.java.syntax._ +import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent } import dev.guardrail.generators.spi.{ CollectionsGeneratorLoader, ModuleLoadResult, ProtocolGeneratorLoader } -import dev.guardrail.generators.RawParameterName import dev.guardrail.terms.collections.{ CollectionsAbstraction, JavaStdLibCollections, JavaVavrCollections } -import dev.guardrail.terms.protocol.PropertyRequirement +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol.{ Discriminator, EnumDefinition, PropertyRequirement } import dev.guardrail.terms.protocol._ -import dev.guardrail.terms.{ CollectionsLibTerms, ProtocolTerms, RenderedEnum, RenderedIntEnum, RenderedLongEnum, RenderedStringEnum } +import dev.guardrail.terms.{ + CollectionsLibTerms, + HeldEnum, + IntHeldEnum, + LanguageTerms, + LongHeldEnum, + ProtocolTerms, + RenderedEnum, + RenderedIntEnum, + RenderedLongEnum, + RenderedStringEnum, + StringHeldEnum, + SwaggerTerms +} import dev.guardrail.{ RuntimeFailure, Target, UserError } class JacksonProtocolGeneratorLoader extends ProtocolGeneratorLoader { @@ -56,6 +89,8 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T override implicit def MonadF: Monad[Target] = Target.targetInstances + import Target.targetInstances // TODO: Remove me. This resolves implicit ambiguity from MonadChain + private val BUILDER_TYPE = StaticJavaParser.parseClassOrInterfaceType("Builder") private val BIG_INTEGER_FQ_TYPE = StaticJavaParser.parseClassOrInterfaceType("java.math.BigInteger") private val BIG_DECIMAL_FQ_TYPE = StaticJavaParser.parseClassOrInterfaceType("java.math.BigDecimal") @@ -71,6 +106,787 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T emptyToNull: EmptyToNullBehaviour ) + private def groupHierarchies( + definitions: Mappish[List, String, Tracker[Schema[_]]] + )(implicit + Sc: LanguageTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[(List[ClassParent[JavaLanguage]], List[(String, Tracker[Schema[_]])])] = { + + def firstInHierarchy(model: Tracker[Schema[_]]): Option[Tracker[ObjectSchema]] = + model + .refine { case x: ComposedSchema => x } { elem => + definitions.value + .collectFirst { + case (clsName, element) + if elem.downField("allOf", _.getAllOf).exists(_.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName"))) => + element + } + .flatMap( + _.refine { case x: ComposedSchema => x }(firstInHierarchy) + .orRefine { case o: ObjectSchema => o }(x => Option(x)) + .getOrElse(None) + ) + } + .getOrElse(None) + + def children(cls: String): List[ClassChild[JavaLanguage]] = definitions.value.flatMap { case (clsName, comp) => + comp + .refine { case x: ComposedSchema => x }(comp => + if ( + comp + .downField("allOf", _.getAllOf()) + .exists(x => x.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$cls"))) + ) { + Some(ClassChild(clsName, comp, children(clsName), getRequiredFieldsRec(comp))) + } else None + ) + .getOrElse(None) + } + + def classHierarchy(cls: String, model: Tracker[Schema[_]]): Target[Option[ClassParent[JavaLanguage]]] = + model + .refine { case c: ComposedSchema => c }(c => + firstInHierarchy(c) + .fold(Option.empty[Discriminator[JavaLanguage]].pure[Target])(Discriminator.fromSchema[JavaLanguage, Target]) + .map(_.map((_, getRequiredFieldsRec(c)))) + ) + .orRefine { case x: Schema[_] => x }(m => Discriminator.fromSchema(m).map(_.map((_, getRequiredFieldsRec(m))))) + .getOrElse(Option.empty[(Discriminator[JavaLanguage], List[String])].pure[Target]) + .map(_.map { case (discriminator, reqFields) => ClassParent(cls, model, children(cls), discriminator, reqFields) }) + + Sw.log.function("groupHierarchies")( + definitions.value + .traverse { case (cls, model) => + for { + hierarchy <- classHierarchy(cls, model) + } yield hierarchy.filterNot(_.children.isEmpty).toLeft((cls, model)) + } + .map(_.partitionEither[ClassParent[JavaLanguage], (String, Tracker[Schema[_]])](identity)) + ) + } + + private[this] def getRequiredFieldsRec(root: Tracker[Schema[_]]): List[String] = { + @scala.annotation.tailrec + def work(values: List[Tracker[Schema[_]]], acc: List[String]): List[String] = { + val required: List[String] = values.flatMap(_.downField("required", _.getRequired()).unwrapTracker) + val next: List[Tracker[Schema[_]]] = + for { + a <- values + b <- a.refine { case x: ComposedSchema => x }(_.downField("allOf", _.getAllOf())).toOption.toList + c <- b.indexedDistribute + } yield c + + val newRequired = acc ++ required + + next match { + case next @ (_ :: _) => work(next, newRequired) + case Nil => newRequired + } + } + work(List(root), Nil) + } + + private[this] def fromEnum[A]( + clsName: String, + schema: Tracker[Schema[A]], + dtoPackage: List[String], + components: Tracker[Option[Components]] + )(implicit + P: ProtocolTerms[JavaLanguage, Target], + F: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target], + wrapEnumSchema: WrapEnumSchema[A] + ): Target[Either[String, EnumDefinition[JavaLanguage]]] = { + import Sc._ + import Sw._ + + def validProg(held: HeldEnum, tpe: Type, fullType: Type): Target[EnumDefinition[JavaLanguage]] = + for { + (pascalValues, wrappedValues) <- held match { + case StringHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(elem) + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedStringEnum[JavaLanguage](elems) + } yield (pascalValues, wrappedValues) + case IntHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedIntEnum[JavaLanguage](elems) + } yield (pascalValues, wrappedValues) + case LongHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedLongEnum[JavaLanguage](elems) + } yield (pascalValues, wrappedValues) + } + members <- renderMembers(clsName, wrappedValues) + encoder <- encodeEnum(clsName, tpe) + decoder <- decodeEnum(clsName, tpe) + + defn <- renderClass(clsName, tpe, wrappedValues) + staticDefns <- renderStaticDefns(clsName, tpe, members, pascalValues, encoder, decoder) + classType <- pureTypeName(clsName) + } yield EnumDefinition[JavaLanguage](clsName, classType, fullType, wrappedValues, defn, staticDefns) + + for { + enum <- extractEnum(schema.map(wrapEnumSchema)) + customTpeName <- SwaggerUtil.customTypeName(schema) + (tpe, _) <- SwaggerUtil.determineTypeName(schema, Tracker.cloneHistory(schema, customTpeName), components) + fullType <- selectType(NonEmptyList.ofInitLast(dtoPackage, clsName)) + res <- enum.traverse(validProg(_, tpe, fullType)) + } yield res + } + + private[this] def getPropertyRequirement( + schema: Tracker[Schema[_]], + isRequired: Boolean, + defaultPropertyRequirement: PropertyRequirement + ): PropertyRequirement = + (for { + isNullable <- schema.downField("nullable", _.getNullable) + } yield (isRequired, isNullable) match { + case (true, None) => PropertyRequirement.Required + case (true, Some(false)) => PropertyRequirement.Required + case (true, Some(true)) => PropertyRequirement.RequiredNullable + case (false, None) => defaultPropertyRequirement + case (false, Some(false)) => PropertyRequirement.Optional + case (false, Some(true)) => PropertyRequirement.OptionalNullable + }).unwrapTracker + + private def defaultValue( + name: NonEmptyList[String], + schema: Tracker[Schema[_]], + requirement: PropertyRequirement, + definitions: List[(String, Tracker[Schema[_]])] + )(implicit + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target] + ): Target[Option[Node]] = { + import Sc._ + import Cl._ + val empty = Option.empty[Node].pure[Target] + schema.downField("$ref", _.get$ref()).indexedDistribute match { + case Some(ref) => + definitions + .collectFirst { + case (cls, refSchema) if ref.unwrapTracker.endsWith(s"/$cls") => + defaultValue(NonEmptyList.of(cls), refSchema, requirement, definitions) + } + .getOrElse(empty) + case None => + schema + .refine { case map: MapSchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => map }(map => + for { + customTpe <- SwaggerUtil.customMapTypeName(map) + result <- customTpe.fold(emptyMap().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case arr: ArraySchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => arr }( + arr => + for { + customTpe <- SwaggerUtil.customArrayTypeName(arr) + result <- customTpe.fold(emptyArray().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case p: BooleanSchema => p }(p => Default(p).extract[Boolean].fold(empty)(litBoolean(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "double" => p }(p => Default(p).extract[Double].fold(empty)(litDouble(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "float" => p }(p => Default(p).extract[Float].fold(empty)(litFloat(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int32" => p }(p => Default(p).extract[Int].fold(empty)(litInt(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int64" => p }(p => Default(p).extract[Long].fold(empty)(litLong(_).map(Some(_)))) + .orRefine { case p: StringSchema if Option(p.getEnum).map(_.asScala).exists(_.nonEmpty) => p }(p => + Default(p).extract[String] match { + case Some(defaultEnumValue) => + for { + enumName <- formatEnumName(defaultEnumValue) + result <- selectTerm(name.append(enumName)) + } yield Some(result) + case None => empty + } + ) + .orRefine { case p: StringSchema => p }(p => Default(p).extract[String].fold(empty)(litString(_).map(Some(_)))) + .getOrElse(empty) + } + } + + /** Handle polymorphic model + */ + private[this] def fromPoly( + hierarchy: ClassParent[JavaLanguage], + concreteTypes: List[PropMeta[JavaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[ProtocolElems[JavaLanguage]] = { + import Sc._ + + def child(hierarchy: ClassHierarchy[JavaLanguage]): List[String] = + hierarchy.children.map(_.name) ::: hierarchy.children.flatMap(child) + def parent(hierarchy: ClassHierarchy[JavaLanguage]): List[String] = + if (hierarchy.children.nonEmpty) hierarchy.name :: hierarchy.children.flatMap(parent) + else Nil + + val children = child(hierarchy).diff(parent(hierarchy)).distinct + val discriminator = hierarchy.discriminator + + for { + parents <- hierarchy.model + .refine[Target[List[SuperClass[JavaLanguage]]]] { case c: ComposedSchema => c }( + extractParents(_, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + ) + .getOrElse(List.empty[SuperClass[JavaLanguage]].pure[Target]) + props <- extractProperties(hierarchy.model) + requiredFields = hierarchy.required ::: hierarchy.children.flatMap(_.required) + params <- props.traverse { case (name, prop) => + for { + typeName <- formatTypeName(name).map(formattedName => NonEmptyList.of(hierarchy.name, formattedName)) + propertyRequirement = getPropertyRequirement(prop, requiredFields.contains(name), defaultPropertyRequirement) + customType <- SwaggerUtil.customTypeName(prop) + resolvedType <- SwaggerUtil + .propMeta[JavaLanguage, Target]( + prop, + components + ) // TODO: This should be resolved via an alternate mechanism that maintains references all the way through, instead of re-deriving and assuming that references are valid + defValue <- defaultValue(typeName, prop, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + res <- transformProperty(hierarchy.name, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + prop, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield res + } + definition <- renderSealedTrait(hierarchy.name, params, discriminator, parents, children) + encoder <- encodeADT(hierarchy.name, hierarchy.discriminator, children) + decoder <- decodeADT(hierarchy.name, hierarchy.discriminator, children) + staticDefns <- renderADTStaticDefns(hierarchy.name, discriminator, encoder, decoder) + tpe <- pureTypeName(hierarchy.name) + fullType <- selectType(NonEmptyList.fromList(dtoPackage :+ hierarchy.name).getOrElse(NonEmptyList.of(hierarchy.name))) + } yield ADT[JavaLanguage]( + name = hierarchy.name, + tpe = tpe, + fullType = fullType, + trt = definition, + staticDefns = staticDefns + ) + } + + private def prepareProperties( + clsName: NonEmptyList[String], + propertyToTypeLookup: Map[String, String], + props: List[(String, Tracker[Schema[_]])], + requiredFields: List[String], + concreteTypes: List[PropMeta[JavaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[(List[ProtocolParameter[JavaLanguage]], List[NestedProtocolElems[JavaLanguage]])] = { + import Sc._ + def getClsName(name: String): NonEmptyList[String] = propertyToTypeLookup.get(name).map(NonEmptyList.of(_)).getOrElse(clsName) + + def processProperty(name: String, schema: Tracker[Schema[_]]): Target[Option[Either[String, NestedProtocolElems[JavaLanguage]]]] = + for { + nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + defn <- schema + .refine[Target[Option[Either[String, NestedProtocolElems[JavaLanguage]]]]] { case x: ObjectSchema => x }(o => + for { + defn <- fromModel( + nestedClassName, + o, + List.empty, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(defn) + ) + .orRefine { case o: ComposedSchema => o }(o => + for { + parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + maybeClassDefinition <- fromModel( + nestedClassName, + o, + parents, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(maybeClassDefinition) + ) + .orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _))) + .orRefine { case s: StringSchema if Option(s.getEnum).map(_.asScala).exists(_.nonEmpty) => s }(s => + fromEnum[String](nestedClassName.last, s, dtoPackage, components).map(Option(_)) + ) + .getOrElse(Option.empty[Either[String, NestedProtocolElems[JavaLanguage]]].pure[Target]) + } yield defn + + for { + paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[JavaLanguage]], Option[NestedProtocolElems[JavaLanguage]])] { + case (name, schema) => + for { + typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + tpe <- selectType(typeName) + maybeNestedDefinition <- processProperty(name, schema) + resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, components) + customType <- SwaggerUtil.customTypeName(schema) + propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement) + defValue <- defaultValue(typeName, schema, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + parameter <- transformProperty(getClsName(name).last, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + schema, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield (Tracker.cloneHistory(schema, parameter), maybeNestedDefinition.flatMap(_.toOption)) + } + (params, nestedDefinitions) = paramsAndNestedDefinitions.unzip + deduplicatedParams <- deduplicateParams(params) + unconflictedParams <- fixConflictingNames(deduplicatedParams) + } yield (unconflictedParams, nestedDefinitions.flatten) + } + + private def deduplicateParams( + params: List[Tracker[ProtocolParameter[JavaLanguage]]] + )(implicit Sw: SwaggerTerms[JavaLanguage, Target], Sc: LanguageTerms[JavaLanguage, Target]): Target[List[ProtocolParameter[JavaLanguage]]] = { + import Sc._ + Foldable[List] + .foldLeftM[Target, Tracker[ProtocolParameter[JavaLanguage]], List[ProtocolParameter[JavaLanguage]]](params, List.empty[ProtocolParameter[JavaLanguage]]) { + (s, ta) => + val a = ta.unwrapTracker + s.find(p => p.name == a.name) match { + case None => (a :: s).pure[Target] + case Some(duplicate) => + for { + newDefaultValue <- findCommonDefaultValue(ta.showHistory, a.defaultValue, duplicate.defaultValue) + newRawType <- findCommonRawType(ta.showHistory, a.rawType, duplicate.rawType) + } yield { + val emptyToNull = if (Set(a.emptyToNull, duplicate.emptyToNull).contains(EmptyIsNull)) EmptyIsNull else EmptyIsEmpty + val redactionBehaviour = if (Set(a.dataRedaction, duplicate.dataRedaction).contains(DataRedacted)) DataRedacted else DataVisible + val mergedParameter = ProtocolParameter[JavaLanguage]( + a.term, + a.baseType, + a.name, + a.dep, + newRawType, + a.readOnlyKey.orElse(duplicate.readOnlyKey), + emptyToNull, + redactionBehaviour, + a.propertyRequirement, + newDefaultValue, + a.propertyValidation + ) + mergedParameter :: s.filter(_.name != a.name) + } + } + } + .map(_.reverse) + } + + private def fixConflictingNames( + params: List[ProtocolParameter[JavaLanguage]] + )(implicit Lt: LanguageTerms[JavaLanguage, Target]): Target[List[ProtocolParameter[JavaLanguage]]] = { + import Lt._ + for { + paramsWithNames <- params.traverse(param => extractTermNameFromParam(param.term).map((_, param))) + counts = paramsWithNames.groupBy(_._1).view.mapValues(_.length).toMap + newParams <- paramsWithNames.traverse { case (name, param) => + if (counts.getOrElse(name, 0) > 1) { + for { + newTermName <- pureTermName(param.name.value) + newMethodParam <- alterMethodParameterName(param.term, newTermName) + } yield ProtocolParameter( + newMethodParam, + param.baseType, + param.name, + param.dep, + param.rawType, + param.readOnlyKey, + param.emptyToNull, + param.dataRedaction, + param.propertyRequirement, + param.defaultValue, + param.propertyValidation + ) + } else { + param.pure[Target] + } + } + } yield newParams + } + + private def modelTypeAlias(clsName: String, abstractModel: Tracker[Schema[_]], components: Tracker[Option[Components]])(implicit + Fw: FrameworkTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[ProtocolElems[JavaLanguage]] = { + import Fw._ + val model: Option[Tracker[ObjectSchema]] = abstractModel + .refine[Option[Tracker[ObjectSchema]]] { case m: ObjectSchema => m }(x => Option(x)) + .orRefine { case m: ComposedSchema => m }( + _.downField("allOf", _.getAllOf()).indexedCosequence + .get(1) + .flatMap( + _.refine { case o: ObjectSchema => o }(Option.apply) + .orRefineFallback(_ => None) + ) + ) + .orRefineFallback(_ => None) + for { + tpe <- model.fold[Target[Type]](objectType(None)) { m => + for { + tpeName <- SwaggerUtil.customTypeName[JavaLanguage, Target, Tracker[ObjectSchema]](m) + (declType, _) <- SwaggerUtil.determineTypeName[JavaLanguage, Target](m, Tracker.cloneHistory(m, tpeName), components) + } yield declType + } + res <- typeAlias(clsName, tpe) + } yield res + } + + private def plainTypeAlias( + clsName: String + )(implicit Fw: FrameworkTerms[JavaLanguage, Target], Sc: LanguageTerms[JavaLanguage, Target]): Target[ProtocolElems[JavaLanguage]] = { + import Fw._ + for { + tpe <- objectType(None) + res <- typeAlias(clsName, tpe) + } yield res + } + + private def typeAlias(clsName: String, tpe: Type): Target[ProtocolElems[JavaLanguage]] = + (RandomType[JavaLanguage](clsName, tpe): ProtocolElems[JavaLanguage]).pure[Target] + + private def fromArray(clsName: String, arr: Tracker[ArraySchema], concreteTypes: List[PropMeta[JavaLanguage]], components: Tracker[Option[Components]])( + implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[ProtocolElems[JavaLanguage]] = + for { + deferredTpe <- SwaggerUtil.modelMetaType(arr, components) + tpe <- extractArrayType(deferredTpe, concreteTypes) + ret <- typeAlias(clsName, tpe) + } yield ret + + private def extractParents( + elem: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])], + concreteTypes: List[PropMeta[JavaLanguage]], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[List[SuperClass[JavaLanguage]]] = { + import Sc._ + + for { + a <- extractSuperClass(elem, definitions) + supper <- a.flatTraverse { case (clsName, _extends, interfaces) => + val concreteInterfacesWithClass = for { + interface <- interfaces + (cls, tracker) <- definitions + result <- tracker + .refine[Tracker[Schema[_]]] { + case x: ComposedSchema if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x + }( + identity _ + ) + .orRefine { case x: Schema[_] if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x }(identity _) + .toOption + } yield cls -> result + val (_, concreteInterfaces) = concreteInterfacesWithClass.unzip + val classMapping = (for { + (cls, schema) <- concreteInterfacesWithClass + (name, _) <- schema.downField("properties", _.getProperties).indexedDistribute.value + } yield (name, cls)).toMap + for { + _extendsProps <- extractProperties(_extends) + requiredFields = getRequiredFieldsRec(_extends) ++ concreteInterfaces.flatMap(getRequiredFieldsRec) + _withProps <- concreteInterfaces.traverse(extractProperties) + props = _extendsProps ++ _withProps.flatten + (params, _) <- prepareProperties( + NonEmptyList.of(clsName), + classMapping, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + interfacesCls = interfaces.flatMap(_.downField("$ref", _.get$ref).unwrapTracker.map(_.split("/").last)) + tpe <- parseTypeName(clsName) + + discriminators <- (_extends :: concreteInterfaces).flatTraverse( + _.refine[Target[List[Discriminator[JavaLanguage]]]] { case m: ObjectSchema => m }(m => Discriminator.fromSchema(m).map(_.toList)) + .getOrElse(List.empty[Discriminator[JavaLanguage]].pure[Target]) + ) + } yield tpe + .map( + SuperClass[JavaLanguage]( + clsName, + _, + interfacesCls, + params, + discriminators + ) + ) + .toList + } + + } yield supper + } + + private[this] def fromModel( + clsName: NonEmptyList[String], + model: Tracker[Schema[_]], + parents: List[SuperClass[JavaLanguage]], + concreteTypes: List[PropMeta[JavaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[Either[String, ClassDefinition[JavaLanguage]]] = { + import Sc._ + + for { + props <- extractProperties(model) + requiredFields = getRequiredFieldsRec(model) + (params, nestedDefinitions) <- prepareProperties( + clsName, + Map.empty, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + encoder <- encodeModel(clsName.last, dtoPackage, params, parents) + decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents) + tpe <- parseTypeName(clsName.last) + fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x))) + staticDefns <- renderDTOStaticDefns(clsName.last, List.empty, encoder, decoder, params) + nestedClasses <- nestedDefinitions.flatTraverse { + case classDefinition: ClassDefinition[JavaLanguage] => + for { + widenClass <- widenClassDefinition(classDefinition.cls) + companionTerm <- pureTermName(classDefinition.name) + companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_)) + case enumDefinition: EnumDefinition[JavaLanguage] => + for { + widenClass <- widenClassDefinition(enumDefinition.cls) + companionTerm <- pureTermName(enumDefinition.name) + companionDefinition <- wrapToObject(companionTerm, enumDefinition.staticDefns.extraImports, enumDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(enumDefinition.staticDefns.definitions)(List(_)) + } + defn <- renderDTOClass(clsName.last, supportPackage, params, parents) + } yield { + val finalStaticDefns = staticDefns.copy(definitions = staticDefns.definitions ++ nestedClasses) + if (parents.isEmpty && props.isEmpty) Left("Entity isn't model"): Either[String, ClassDefinition[JavaLanguage]] + else tpe.toRight("Empty entity name").map(ClassDefinition[JavaLanguage](clsName.last, _, fullType, defn, finalStaticDefns, parents)) + } + } + + override def fromSwagger( + swagger: Tracker[OpenAPI], + dtoPackage: List[String], + supportPackage: NonEmptyList[String], + defaultPropertyRequirement: PropertyRequirement + )(implicit + F: FrameworkTerms[JavaLanguage, Target], + P: ProtocolTerms[JavaLanguage, Target], + Sc: LanguageTerms[JavaLanguage, Target], + Cl: CollectionsLibTerms[JavaLanguage, Target], + Sw: SwaggerTerms[JavaLanguage, Target] + ): Target[ProtocolDefinitions[JavaLanguage]] = { + import Sc._ + + val components = swagger.downField("components", _.getComponents()) + val definitions = components.flatDownField("schemas", _.getSchemas()).indexedCosequence + Sw.log.function("ProtocolGenerator.fromSwagger")(for { + (hierarchies, definitionsWithoutPoly) <- groupHierarchies(definitions) + + concreteTypes <- SwaggerUtil.extractConcreteTypes[JavaLanguage, Target](definitions.value, components) + polyADTs <- hierarchies.traverse(fromPoly(_, concreteTypes, definitions.value, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components)) + elems <- definitionsWithoutPoly.traverse { case (clsName, model) => + model + .refine { case c: ComposedSchema => c }(comp => + for { + formattedClsName <- formatTypeName(clsName) + parents <- extractParents(comp, definitions.value, concreteTypes, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components) + model <- fromModel( + clsName = NonEmptyList.of(formattedClsName), + model = comp, + parents = parents, + concreteTypes = concreteTypes, + definitions = definitions.value, + dtoPackage = dtoPackage, + supportPackage = supportPackage.toList, + defaultPropertyRequirement = defaultPropertyRequirement, + components = components + ) + alias <- modelTypeAlias(formattedClsName, comp, components) + } yield model.getOrElse(alias) + ) + .orRefine { case a: ArraySchema => a }(arr => + for { + formattedClsName <- formatTypeName(clsName) + array <- fromArray(formattedClsName, arr, concreteTypes, components) + } yield array + ) + .orRefine { case o: ObjectSchema => o }(m => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum[Object](formattedClsName, m, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + m, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + alias <- modelTypeAlias(formattedClsName, m, components) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: StringSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[JavaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: IntegerSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[JavaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .valueOr(x => + for { + formattedClsName <- formatTypeName(clsName) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[JavaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + res <- typeAlias(formattedClsName, declType) + } yield res + ) + } + protoImports <- protocolImports() + pkgImports <- packageObjectImports() + pkgObjectContents <- packageObjectContents() + implicitsObject <- implicitsObject() + + polyADTElems <- ProtocolElems.resolve[JavaLanguage, Target](polyADTs) + strictElems <- ProtocolElems.resolve[JavaLanguage, Target](elems) + } yield ProtocolDefinitions[JavaLanguage](strictElems ++ polyADTElems, protoImports, pkgImports, pkgObjectContents, implicitsObject)) + } + // returns a tuple of (requiredTerms, optionalTerms) // note that required terms _that have a default value_ are conceptually optional. private def sortParams( @@ -214,19 +1030,19 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T ) } - override def renderMembers( + private def renderMembers( clsName: String, elems: RenderedEnum[JavaLanguage] ) = Target.pure(None) - override def encodeEnum(clsName: String, tpe: com.github.javaparser.ast.`type`.Type): Target[Option[BodyDeclaration[_ <: BodyDeclaration[_]]]] = + private def encodeEnum(clsName: String, tpe: com.github.javaparser.ast.`type`.Type): Target[Option[BodyDeclaration[_ <: BodyDeclaration[_]]]] = Target.pure(None) - override def decodeEnum(clsName: String, tpe: com.github.javaparser.ast.`type`.Type): Target[Option[BodyDeclaration[_ <: BodyDeclaration[_]]]] = + private def decodeEnum(clsName: String, tpe: com.github.javaparser.ast.`type`.Type): Target[Option[BodyDeclaration[_ <: BodyDeclaration[_]]]] = Target.pure(None) - override def renderClass( + private def renderClass( clsName: String, tpe: com.github.javaparser.ast.`type`.Type, elems: RenderedEnum[JavaLanguage] @@ -382,7 +1198,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T Target.pure(enumClass) } - override def renderStaticDefns( + private def renderStaticDefns( clsName: String, tpe: com.github.javaparser.ast.`type`.Type, members: Option[Nothing], @@ -404,7 +1220,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T override def buildAccessor(clsName: String, termName: String) = Target.pure(new Name(s"${clsName}.${termName}")) - override def renderDTOClass( + private def renderDTOClass( clsName: String, supportPackage: List[String], selfParams: List[ProtocolParameter[JavaLanguage]], @@ -820,7 +1636,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T } yield dtoClass } - override def extractProperties(swagger: Tracker[Schema[_]]) = + private def extractProperties(swagger: Tracker[Schema[_]]) = swagger .refine[Target[List[(String, Tracker[Schema[_]])]]] { case m: ObjectSchema => m }(m => Target.pure(m.downField("properties", _.getProperties()).indexedCosequence.value) @@ -840,7 +1656,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T ) .getOrElse(Target.pure(List.empty[(String, Tracker[Schema[_]])])) - override def transformProperty( + private def transformProperty( clsName: String, dtoPackage: List[String], supportPackage: List[String], @@ -929,7 +1745,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T ) } - override def encodeModel( + private def encodeModel( clsName: String, dtoPackage: List[String], selfParams: List[ProtocolParameter[JavaLanguage]], @@ -937,7 +1753,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T ) = Target.pure(None) - override def decodeModel( + private def decodeModel( clsName: String, dtoPackage: List[String], supportPackage: List[String], @@ -946,16 +1762,16 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T ) = Target.pure(None) - override def renderDTOStaticDefns( + private def renderDTOStaticDefns( clsName: String, deps: List[com.github.javaparser.ast.expr.Name], encoder: Option[com.github.javaparser.ast.body.VariableDeclarator], decoder: Option[com.github.javaparser.ast.body.VariableDeclarator], protocolParameters: List[ProtocolParameter[JavaLanguage]] ) = - Target.pure(StaticDefns(clsName, List.empty, List.empty)) + Target.pure(StaticDefns[JavaLanguage](clsName, List.empty, List.empty)) - override def extractArrayType( + private def extractArrayType( arr: core.ResolvedType[JavaLanguage], concreteTypes: List[PropMeta[JavaLanguage]] ): Target[Type] = @@ -974,10 +1790,10 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T } } yield result - override def extractConcreteTypes(definitions: Either[String, List[PropMeta[JavaLanguage]]]) = + private def extractConcreteTypes(definitions: Either[String, List[PropMeta[JavaLanguage]]]) = definitions.fold[Target[List[PropMeta[JavaLanguage]]]](Target.raiseUserError, Target.pure) - override def protocolImports() = + private def protocolImports() = (List( "com.fasterxml.jackson.annotation.JsonCreator", "com.fasterxml.jackson.annotation.JsonIgnoreProperties", @@ -993,16 +1809,16 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T override def generateSupportDefinitions() = Target.pure(List.empty) - override def packageObjectImports() = + private def packageObjectImports() = Target.pure(List.empty) - override def packageObjectContents() = + private def packageObjectContents() = Target.pure(List.empty) - override def implicitsObject() = + private def implicitsObject() = Target.pure(None) - override def renderSealedTrait( + private def renderSealedTrait( className: String, selfParams: List[ProtocolParameter[JavaLanguage]], discriminator: Discriminator[JavaLanguage], @@ -1108,7 +1924,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T } yield abstractClass } - override def extractSuperClass( + private def extractSuperClass( swagger: Tracker[ComposedSchema], definitions: List[(String, Tracker[Schema[_]])] ) = { @@ -1130,7 +1946,7 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T Target.pure(allParents(swagger)) } - override def renderADTStaticDefns( + private def renderADTStaticDefns( clsName: String, discriminator: Discriminator[JavaLanguage], encoder: Option[com.github.javaparser.ast.body.VariableDeclarator], @@ -1148,14 +1964,14 @@ class JacksonGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage, T List.empty ) - override def decodeADT( + private def decodeADT( clsName: String, discriminator: Discriminator[JavaLanguage], children: List[String] = Nil ) = Target.pure(None) - override def encodeADT( + private def encodeADT( clsName: String, discriminator: Discriminator[JavaLanguage], children: List[String] = Nil diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala index b73850ecc6..9988d8b5ed 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala @@ -1,23 +1,42 @@ package dev.guardrail.generators.scala.circe import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, _ } +import _root_.io.swagger.v3.oas.models.{ Components, OpenAPI } +import cats.Foldable import cats.Monad import cats.data.{ NonEmptyList, NonEmptyVector } import cats.syntax.all._ - +import scala.jdk.CollectionConverters._ import scala.meta.{ Defn, _ } import scala.reflect.runtime.universe.typeTag + import dev.guardrail.core -import dev.guardrail.core.extract.{ DataRedaction, EmptyValueIsNull } +import dev.guardrail.core.extract.{ DataRedaction, Default, EmptyValueIsNull } import dev.guardrail.core.implicits._ -import dev.guardrail.core.{ DataVisible, EmptyIsEmpty, EmptyIsNull, LiteralRawType, ReifiedRawType, ResolvedType, SupportDefinition, Tracker } -import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } -import dev.guardrail.generators.scala.{ CirceModelGenerator, ScalaGenerator, ScalaLanguage } -import dev.guardrail.generators.RawParameterName +import dev.guardrail.core.{ DataRedacted, DataVisible, EmptyIsEmpty, EmptyIsNull, LiteralRawType, Mappish, ReifiedRawType, SupportDefinition, Tracker } +import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEnumSchema, wrapObjectEnumSchema, wrapStringEnumSchema } +import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent } import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations +import dev.guardrail.generators.scala.{ CirceModelGenerator, ScalaGenerator, ScalaLanguage } +import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.PropertyRequirement import dev.guardrail.terms.protocol._ -import dev.guardrail.terms.{ ProtocolTerms, RenderedEnum, RenderedIntEnum, RenderedLongEnum, RenderedStringEnum } +import dev.guardrail.terms.{ + CollectionsLibTerms, + HeldEnum, + IntHeldEnum, + LanguageTerms, + LongHeldEnum, + ProtocolTerms, + RenderedEnum, + RenderedIntEnum, + RenderedLongEnum, + RenderedStringEnum, + StringHeldEnum, + SwaggerTerms +} import dev.guardrail.{ SwaggerUtil, Target, UserError } class CirceProtocolGeneratorLoader extends ProtocolGeneratorLoader { @@ -46,6 +65,794 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa override implicit def MonadF: Monad[Target] = Target.targetInstances + import Target.targetInstances // TODO: Remove me. This resolves implicit ambiguity from MonadChain + + override def fromSwagger( + swagger: Tracker[OpenAPI], + dtoPackage: List[String], + supportPackage: NonEmptyList[String], + defaultPropertyRequirement: PropertyRequirement + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolDefinitions[ScalaLanguage]] = { + import Sc._ + + val components = swagger.downField("components", _.getComponents()) + val definitions = components.flatDownField("schemas", _.getSchemas()).indexedCosequence + Sw.log.function("ProtocolGenerator.fromSwagger")(for { + (hierarchies, definitionsWithoutPoly) <- groupHierarchies(definitions) + + concreteTypes <- SwaggerUtil.extractConcreteTypes[ScalaLanguage, Target](definitions.value, components) + polyADTs <- hierarchies.traverse(fromPoly(_, concreteTypes, definitions.value, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components)) + elems <- definitionsWithoutPoly.traverse { case (clsName, model) => + model + .refine { case c: ComposedSchema => c }(comp => + for { + formattedClsName <- formatTypeName(clsName) + parents <- extractParents(comp, definitions.value, concreteTypes, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components) + model <- fromModel( + clsName = NonEmptyList.of(formattedClsName), + model = comp, + parents = parents, + concreteTypes = concreteTypes, + definitions = definitions.value, + dtoPackage = dtoPackage, + supportPackage = supportPackage.toList, + defaultPropertyRequirement = defaultPropertyRequirement, + components = components + ) + alias <- modelTypeAlias(formattedClsName, comp, components) + } yield model.getOrElse(alias) + ) + .orRefine { case a: ArraySchema => a }(arr => + for { + formattedClsName <- formatTypeName(clsName) + array <- fromArray(formattedClsName, arr, concreteTypes, components) + } yield array + ) + .orRefine { case o: ObjectSchema => o }(m => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum[Object](formattedClsName, m, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + m, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + alias <- modelTypeAlias(formattedClsName, m, components) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: StringSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: IntegerSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .valueOr(x => + for { + formattedClsName <- formatTypeName(clsName) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + res <- typeAlias(formattedClsName, declType) + } yield res + ) + } + protoImports <- protocolImports() + pkgImports <- packageObjectImports() + pkgObjectContents <- packageObjectContents() + implicitsObject <- implicitsObject() + + polyADTElems <- ProtocolElems.resolve[ScalaLanguage, Target](polyADTs) + strictElems <- ProtocolElems.resolve[ScalaLanguage, Target](elems) + } yield ProtocolDefinitions[ScalaLanguage](strictElems ++ polyADTElems, protoImports, pkgImports, pkgObjectContents, implicitsObject)) + } + + private[this] def getRequiredFieldsRec(root: Tracker[Schema[_]]): List[String] = { + @scala.annotation.tailrec + def work(values: List[Tracker[Schema[_]]], acc: List[String]): List[String] = { + val required: List[String] = values.flatMap(_.downField("required", _.getRequired()).unwrapTracker) + val next: List[Tracker[Schema[_]]] = + for { + a <- values + b <- a.refine { case x: ComposedSchema => x }(_.downField("allOf", _.getAllOf())).toOption.toList + c <- b.indexedDistribute + } yield c + + val newRequired = acc ++ required + + next match { + case next @ (_ :: _) => work(next, newRequired) + case Nil => newRequired + } + } + work(List(root), Nil) + } + + private[this] def fromEnum[A]( + clsName: String, + schema: Tracker[Schema[A]], + dtoPackage: List[String], + components: Tracker[Option[Components]] + )(implicit + P: ProtocolTerms[ScalaLanguage, Target], + F: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target], + wrapEnumSchema: WrapEnumSchema[A] + ): Target[Either[String, EnumDefinition[ScalaLanguage]]] = { + import Sc._ + import Sw._ + + def validProg(held: HeldEnum, tpe: scala.meta.Type, fullType: scala.meta.Type): Target[EnumDefinition[ScalaLanguage]] = + for { + (pascalValues, wrappedValues) <- held match { + case StringHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(elem) + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedStringEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case IntHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedIntEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case LongHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedLongEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + } + members <- renderMembers(clsName, wrappedValues) + encoder <- encodeEnum(clsName, tpe) + decoder <- decodeEnum(clsName, tpe) + + defn <- renderClass(clsName, tpe, wrappedValues) + staticDefns <- renderStaticDefns(clsName, tpe, members, pascalValues, encoder, decoder) + classType <- pureTypeName(clsName) + } yield EnumDefinition[ScalaLanguage](clsName, classType, fullType, wrappedValues, defn, staticDefns) + + for { + enum <- extractEnum(schema.map(wrapEnumSchema)) + customTpeName <- SwaggerUtil.customTypeName(schema) + (tpe, _) <- SwaggerUtil.determineTypeName(schema, Tracker.cloneHistory(schema, customTpeName), components) + fullType <- selectType(NonEmptyList.ofInitLast(dtoPackage, clsName)) + res <- enum.traverse(validProg(_, tpe, fullType)) + } yield res + } + + private[this] def getPropertyRequirement( + schema: Tracker[Schema[_]], + isRequired: Boolean, + defaultPropertyRequirement: PropertyRequirement + ): PropertyRequirement = + (for { + isNullable <- schema.downField("nullable", _.getNullable) + } yield (isRequired, isNullable) match { + case (true, None) => PropertyRequirement.Required + case (true, Some(false)) => PropertyRequirement.Required + case (true, Some(true)) => PropertyRequirement.RequiredNullable + case (false, None) => defaultPropertyRequirement + case (false, Some(false)) => PropertyRequirement.Optional + case (false, Some(true)) => PropertyRequirement.OptionalNullable + }).unwrapTracker + + /** Handle polymorphic model + */ + private[this] def fromPoly( + hierarchy: ClassParent[ScalaLanguage], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Sc._ + + def child(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + hierarchy.children.map(_.name) ::: hierarchy.children.flatMap(child) + def parent(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + if (hierarchy.children.nonEmpty) hierarchy.name :: hierarchy.children.flatMap(parent) + else Nil + + val children = child(hierarchy).diff(parent(hierarchy)).distinct + val discriminator = hierarchy.discriminator + + for { + parents <- hierarchy.model + .refine[Target[List[SuperClass[ScalaLanguage]]]] { case c: ComposedSchema => c }( + extractParents(_, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + ) + .getOrElse(List.empty[SuperClass[ScalaLanguage]].pure[Target]) + props <- extractProperties(hierarchy.model) + requiredFields = hierarchy.required ::: hierarchy.children.flatMap(_.required) + params <- props.traverse { case (name, prop) => + for { + typeName <- formatTypeName(name).map(formattedName => NonEmptyList.of(hierarchy.name, formattedName)) + propertyRequirement = getPropertyRequirement(prop, requiredFields.contains(name), defaultPropertyRequirement) + customType <- SwaggerUtil.customTypeName(prop) + resolvedType <- SwaggerUtil + .propMeta[ScalaLanguage, Target]( + prop, + components + ) // TODO: This should be resolved via an alternate mechanism that maintains references all the way through, instead of re-deriving and assuming that references are valid + defValue <- defaultValue(typeName, prop, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + res <- transformProperty(hierarchy.name, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + prop, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield res + } + definition <- renderSealedTrait(hierarchy.name, params, discriminator, parents, children) + encoder <- encodeADT(hierarchy.name, hierarchy.discriminator, children) + decoder <- decodeADT(hierarchy.name, hierarchy.discriminator, children) + staticDefns <- renderADTStaticDefns(hierarchy.name, discriminator, encoder, decoder) + tpe <- pureTypeName(hierarchy.name) + fullType <- selectType(NonEmptyList.fromList(dtoPackage :+ hierarchy.name).getOrElse(NonEmptyList.of(hierarchy.name))) + } yield ADT[ScalaLanguage]( + name = hierarchy.name, + tpe = tpe, + fullType = fullType, + trt = definition, + staticDefns = staticDefns + ) + } + + private def extractParents( + elem: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])], + concreteTypes: List[PropMeta[ScalaLanguage]], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[List[SuperClass[ScalaLanguage]]] = { + import Sc._ + + for { + a <- extractSuperClass(elem, definitions) + supper <- a.flatTraverse { case (clsName, _extends, interfaces) => + val concreteInterfacesWithClass = for { + interface <- interfaces + (cls, tracker) <- definitions + result <- tracker + .refine[Tracker[Schema[_]]] { + case x: ComposedSchema if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x + }( + identity _ + ) + .orRefine { case x: Schema[_] if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x }(identity _) + .toOption + } yield cls -> result + val (_, concreteInterfaces) = concreteInterfacesWithClass.unzip + val classMapping = (for { + (cls, schema) <- concreteInterfacesWithClass + (name, _) <- schema.downField("properties", _.getProperties).indexedDistribute.value + } yield (name, cls)).toMap + for { + _extendsProps <- extractProperties(_extends) + requiredFields = getRequiredFieldsRec(_extends) ++ concreteInterfaces.flatMap(getRequiredFieldsRec) + _withProps <- concreteInterfaces.traverse(extractProperties) + props = _extendsProps ++ _withProps.flatten + (params, _) <- prepareProperties( + NonEmptyList.of(clsName), + classMapping, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + interfacesCls = interfaces.flatMap(_.downField("$ref", _.get$ref).unwrapTracker.map(_.split("/").last)) + tpe <- parseTypeName(clsName) + + discriminators <- (_extends :: concreteInterfaces).flatTraverse( + _.refine[Target[List[Discriminator[ScalaLanguage]]]] { case m: ObjectSchema => m }(m => Discriminator.fromSchema(m).map(_.toList)) + .getOrElse(List.empty[Discriminator[ScalaLanguage]].pure[Target]) + ) + } yield tpe + .map( + SuperClass[ScalaLanguage]( + clsName, + _, + interfacesCls, + params, + discriminators + ) + ) + .toList + } + + } yield supper + } + + private[this] def fromModel( + clsName: NonEmptyList[String], + model: Tracker[Schema[_]], + parents: List[SuperClass[ScalaLanguage]], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Either[String, ClassDefinition[ScalaLanguage]]] = { + import Sc._ + + for { + props <- extractProperties(model) + requiredFields = getRequiredFieldsRec(model) + (params, nestedDefinitions) <- prepareProperties( + clsName, + Map.empty, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + encoder <- encodeModel(clsName.last, dtoPackage, params, parents) + decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents) + tpe <- parseTypeName(clsName.last) + fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x))) + staticDefns <- renderDTOStaticDefns(clsName.last, List.empty, encoder, decoder, params) + nestedClasses <- nestedDefinitions.flatTraverse { + case classDefinition: ClassDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(classDefinition.cls) + companionTerm <- pureTermName(classDefinition.name) + companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_)) + case enumDefinition: EnumDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(enumDefinition.cls) + companionTerm <- pureTermName(enumDefinition.name) + companionDefinition <- wrapToObject(companionTerm, enumDefinition.staticDefns.extraImports, enumDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(enumDefinition.staticDefns.definitions)(List(_)) + } + defn <- renderDTOClass(clsName.last, supportPackage, params, parents) + } yield { + val finalStaticDefns = staticDefns.copy(definitions = staticDefns.definitions ++ nestedClasses) + if (parents.isEmpty && props.isEmpty) Left("Entity isn't model"): Either[String, ClassDefinition[ScalaLanguage]] + else tpe.toRight("Empty entity name").map(ClassDefinition[ScalaLanguage](clsName.last, _, fullType, defn, finalStaticDefns, parents)) + } + + } + + private def prepareProperties( + clsName: NonEmptyList[String], + propertyToTypeLookup: Map[String, String], + props: List[(String, Tracker[Schema[_]])], + requiredFields: List[String], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ProtocolParameter[ScalaLanguage]], List[NestedProtocolElems[ScalaLanguage]])] = { + import Sc._ + def getClsName(name: String): NonEmptyList[String] = propertyToTypeLookup.get(name).map(NonEmptyList.of(_)).getOrElse(clsName) + + def processProperty(name: String, schema: Tracker[Schema[_]]): Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]] = + for { + nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + defn <- schema + .refine[Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]]] { case x: ObjectSchema => x }(o => + for { + defn <- fromModel( + nestedClassName, + o, + List.empty, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(defn) + ) + .orRefine { case o: ComposedSchema => o }(o => + for { + parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + maybeClassDefinition <- fromModel( + nestedClassName, + o, + parents, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(maybeClassDefinition) + ) + .orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _))) + .orRefine { case s: StringSchema if Option(s.getEnum).map(_.asScala).exists(_.nonEmpty) => s }(s => + fromEnum(nestedClassName.last, s, dtoPackage, components).map(Option(_)) + ) + .getOrElse(Option.empty[Either[String, NestedProtocolElems[ScalaLanguage]]].pure[Target]) + } yield defn + + for { + paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[ScalaLanguage]], Option[NestedProtocolElems[ScalaLanguage]])] { + case (name, schema) => + for { + typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + tpe <- selectType(typeName) + maybeNestedDefinition <- processProperty(name, schema) + resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, components) + customType <- SwaggerUtil.customTypeName(schema) + propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement) + defValue <- defaultValue(typeName, schema, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + parameter <- transformProperty(getClsName(name).last, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + schema, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield (Tracker.cloneHistory(schema, parameter), maybeNestedDefinition.flatMap(_.toOption)) + } + (params, nestedDefinitions) = paramsAndNestedDefinitions.unzip + deduplicatedParams <- deduplicateParams(params) + unconflictedParams <- fixConflictingNames(deduplicatedParams) + } yield (unconflictedParams, nestedDefinitions.flatten) + } + + private def deduplicateParams( + params: List[Tracker[ProtocolParameter[ScalaLanguage]]] + )(implicit Sw: SwaggerTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Sc._ + Foldable[List] + .foldLeftM[Target, Tracker[ProtocolParameter[ScalaLanguage]], List[ProtocolParameter[ScalaLanguage]]]( + params, + List.empty[ProtocolParameter[ScalaLanguage]] + ) { (s, ta) => + val a = ta.unwrapTracker + s.find(p => p.name == a.name) match { + case None => (a :: s).pure[Target] + case Some(duplicate) => + for { + newDefaultValue <- findCommonDefaultValue(ta.showHistory, a.defaultValue, duplicate.defaultValue) + newRawType <- findCommonRawType(ta.showHistory, a.rawType, duplicate.rawType) + } yield { + val emptyToNull = if (Set(a.emptyToNull, duplicate.emptyToNull).contains(EmptyIsNull)) EmptyIsNull else EmptyIsEmpty + val redactionBehaviour = if (Set(a.dataRedaction, duplicate.dataRedaction).contains(DataRedacted)) DataRedacted else DataVisible + val mergedParameter = ProtocolParameter[ScalaLanguage]( + a.term, + a.baseType, + a.name, + a.dep, + newRawType, + a.readOnlyKey.orElse(duplicate.readOnlyKey), + emptyToNull, + redactionBehaviour, + a.propertyRequirement, + newDefaultValue, + a.propertyValidation + ) + mergedParameter :: s.filter(_.name != a.name) + } + } + } + .map(_.reverse) + } + + private def fixConflictingNames( + params: List[ProtocolParameter[ScalaLanguage]] + )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Lt._ + for { + paramsWithNames <- params.traverse(param => extractTermNameFromParam(param.term).map((_, param))) + counts = paramsWithNames.groupBy(_._1).view.mapValues(_.length).toMap + newParams <- paramsWithNames.traverse { case (name, param) => + if (counts.getOrElse(name, 0) > 1) { + for { + newTermName <- pureTermName(param.name.value) + newMethodParam <- alterMethodParameterName(param.term, newTermName) + } yield ProtocolParameter( + newMethodParam, + param.baseType, + param.name, + param.dep, + param.rawType, + param.readOnlyKey, + param.emptyToNull, + param.dataRedaction, + param.propertyRequirement, + param.defaultValue, + param.propertyValidation + ) + } else { + param.pure[Target] + } + } + } yield newParams + } + + private def modelTypeAlias(clsName: String, abstractModel: Tracker[Schema[_]], components: Tracker[Option[Components]])(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + val model: Option[Tracker[ObjectSchema]] = abstractModel + .refine[Option[Tracker[ObjectSchema]]] { case m: ObjectSchema => m }(x => Option(x)) + .orRefine { case m: ComposedSchema => m }( + _.downField("allOf", _.getAllOf()).indexedCosequence + .get(1) + .flatMap( + _.refine { case o: ObjectSchema => o }(Option.apply) + .orRefineFallback(_ => None) + ) + ) + .orRefineFallback(_ => None) + for { + tpe <- model.fold[Target[scala.meta.Type]](objectType(None)) { m => + for { + tpeName <- SwaggerUtil.customTypeName[ScalaLanguage, Target, Tracker[ObjectSchema]](m) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](m, Tracker.cloneHistory(m, tpeName), components) + } yield declType + } + res <- typeAlias(clsName, tpe) + } yield res + } + + private def plainTypeAlias( + clsName: String + )(implicit Fw: FrameworkTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + for { + tpe <- objectType(None) + res <- typeAlias(clsName, tpe) + } yield res + } + + private def typeAlias(clsName: String, tpe: scala.meta.Type): Target[ProtocolElems[ScalaLanguage]] = + (RandomType[ScalaLanguage](clsName, tpe): ProtocolElems[ScalaLanguage]).pure[Target] + + private def fromArray(clsName: String, arr: Tracker[ArraySchema], concreteTypes: List[PropMeta[ScalaLanguage]], components: Tracker[Option[Components]])( + implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = + for { + deferredTpe <- SwaggerUtil.modelMetaType(arr, components) + tpe <- extractArrayType(deferredTpe, concreteTypes) + ret <- typeAlias(clsName, tpe) + } yield ret + + /** returns objects grouped into hierarchies + */ + private def groupHierarchies( + definitions: Mappish[List, String, Tracker[Schema[_]]] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ClassParent[ScalaLanguage]], List[(String, Tracker[Schema[_]])])] = { + + def firstInHierarchy(model: Tracker[Schema[_]]): Option[Tracker[ObjectSchema]] = + model + .refine { case x: ComposedSchema => x } { elem => + definitions.value + .collectFirst { + case (clsName, element) + if elem.downField("allOf", _.getAllOf).exists(_.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName"))) => + element + } + .flatMap( + _.refine { case x: ComposedSchema => x }(firstInHierarchy) + .orRefine { case o: ObjectSchema => o }(x => Option(x)) + .getOrElse(None) + ) + } + .getOrElse(None) + + def children(cls: String): List[ClassChild[ScalaLanguage]] = definitions.value.flatMap { case (clsName, comp) => + comp + .refine { case x: ComposedSchema => x }(comp => + if ( + comp + .downField("allOf", _.getAllOf()) + .exists(x => x.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$cls"))) + ) { + Some(ClassChild(clsName, comp, children(clsName), getRequiredFieldsRec(comp))) + } else None + ) + .getOrElse(None) + } + + def classHierarchy(cls: String, model: Tracker[Schema[_]]): Target[Option[ClassParent[ScalaLanguage]]] = + model + .refine { case c: ComposedSchema => c }(c => + firstInHierarchy(c) + .fold(Option.empty[Discriminator[ScalaLanguage]].pure[Target])(Discriminator.fromSchema[ScalaLanguage, Target]) + .map(_.map((_, getRequiredFieldsRec(c)))) + ) + .orRefine { case x: Schema[_] => x }(m => Discriminator.fromSchema(m).map(_.map((_, getRequiredFieldsRec(m))))) + .getOrElse(Option.empty[(Discriminator[ScalaLanguage], List[String])].pure[Target]) + .map(_.map { case (discriminator, reqFields) => ClassParent(cls, model, children(cls), discriminator, reqFields) }) + + Sw.log.function("groupHierarchies")( + definitions.value + .traverse { case (cls, model) => + for { + hierarchy <- classHierarchy(cls, model) + } yield hierarchy.filterNot(_.children.isEmpty).toLeft((cls, model)) + } + .map(_.partitionEither[ClassParent[ScalaLanguage], (String, Tracker[Schema[_]])](identity)) + ) + } + + private def defaultValue( + name: NonEmptyList[String], + schema: Tracker[Schema[_]], + requirement: PropertyRequirement, + definitions: List[(String, Tracker[Schema[_]])] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target] + ): Target[Option[scala.meta.Term]] = { + import Sc._ + import Cl._ + val empty = Option.empty[scala.meta.Term].pure[Target] + schema.downField("$ref", _.get$ref()).indexedDistribute match { + case Some(ref) => + definitions + .collectFirst { + case (cls, refSchema) if ref.unwrapTracker.endsWith(s"/$cls") => + defaultValue(NonEmptyList.of(cls), refSchema, requirement, definitions) + } + .getOrElse(empty) + case None => + schema + .refine { case map: MapSchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => map }(map => + for { + customTpe <- SwaggerUtil.customMapTypeName(map) + result <- customTpe.fold(emptyMap().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case arr: ArraySchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => arr }( + arr => + for { + customTpe <- SwaggerUtil.customArrayTypeName(arr) + result <- customTpe.fold(emptyArray().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case p: BooleanSchema => p }(p => Default(p).extract[Boolean].fold(empty)(litBoolean(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "double" => p }(p => Default(p).extract[Double].fold(empty)(litDouble(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "float" => p }(p => Default(p).extract[Float].fold(empty)(litFloat(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int32" => p }(p => Default(p).extract[Int].fold(empty)(litInt(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int64" => p }(p => Default(p).extract[Long].fold(empty)(litLong(_).map(Some(_)))) + .orRefine { case p: StringSchema if Option(p.getEnum).map(_.asScala).exists(_.nonEmpty) => p }(p => + Default(p).extract[String] match { + case Some(defaultEnumValue) => + for { + enumName <- formatEnumName(defaultEnumValue) + result <- selectTerm(name.append(enumName)) + } yield Some(result) + case None => empty + } + ) + .orRefine { case p: StringSchema => p }(p => Default(p).extract[String].fold(empty)(litString(_).map(Some(_)))) + .getOrElse(empty) + } + } + private def suffixClsName(prefix: String, clsName: String): Pat.Var = Pat.Var(Term.Name(s"${prefix}${clsName}")) private def lookupTypeName(tpeName: String, concreteTypes: List[PropMeta[ScalaLanguage]])(f: Type => Type): Option[Type] = @@ -54,7 +861,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa .map(_.tpe) .map(f) - override def renderMembers(clsName: String, elems: RenderedEnum[ScalaLanguage]) = { + private def renderMembers(clsName: String, elems: RenderedEnum[ScalaLanguage]) = { val fields = elems match { case RenderedStringEnum(elems) => elems.map { case (value, termName, _) => @@ -77,27 +884,27 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa """)) } - override def encodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = + private def encodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = Target.pure(Some(q""" implicit val ${suffixClsName("encode", clsName)}: _root_.io.circe.Encoder[${Type.Name(clsName)}] = _root_.io.circe.Encoder[${tpe}].contramap(_.value) """)) - override def decodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = + private def decodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = Target.pure(Some(q""" implicit val ${suffixClsName("decode", clsName)}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = _root_.io.circe.Decoder[${tpe}].emap(value => from(value).toRight(${Term .Interpolate(Term.Name("s"), List(Lit.String(""), Lit.String(s" not a member of ${clsName}")), List(Term.Name("value")))})) """)) - override def renderClass(clsName: String, tpe: scala.meta.Type, elems: RenderedEnum[ScalaLanguage]) = + private def renderClass(clsName: String, tpe: scala.meta.Type, elems: RenderedEnum[ScalaLanguage]) = Target.pure(q""" sealed abstract class ${Type.Name(clsName)}(val value: ${tpe}) extends _root_.scala.Product with _root_.scala.Serializable { override def toString: String = value.toString } """) - override def renderStaticDefns( + private def renderStaticDefns( clsName: String, tpe: scala.meta.Type, members: Option[scala.meta.Defn.Object], @@ -132,7 +939,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa override def buildAccessor(clsName: String, termName: String) = Target.pure(q"${Term.Name(clsName)}.${Term.Name(termName)}") - override def extractProperties(swagger: Tracker[Schema[_]]) = + private def extractProperties(swagger: Tracker[Schema[_]]) = swagger .refine[Target[List[(String, Tracker[Schema[_]])]]] { case o: ObjectSchema => o }(m => Target.pure(m.downField("properties", _.getProperties()).indexedCosequence.value) @@ -150,7 +957,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) .getOrElse(Target.pure(List.empty[(String, Tracker[Schema[_]])])) - override def transformProperty( + private def transformProperty( clsName: String, dtoPackage: List[String], supportPackage: List[String], @@ -159,7 +966,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa name: String, fieldName: String, property: Tracker[Schema[_]], - meta: ResolvedType[ScalaLanguage], + meta: core.ResolvedType[ScalaLanguage], requirement: PropertyRequirement, isCustomType: Boolean, defaultValue: Option[scala.meta.Term] @@ -248,7 +1055,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) } - override def renderDTOClass( + private def renderDTOClass( clsName: String, supportPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], @@ -290,7 +1097,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa Target.pure(code) } - override def encodeModel( + private def encodeModel( clsName: String, dtoPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], @@ -355,7 +1162,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa """)) } - override def decodeModel( + private def decodeModel( clsName: String, dtoPackage: List[String], supportPackage: List[String], @@ -465,7 +1272,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa """) } - override def renderDTOStaticDefns( + private def renderDTOStaticDefns( clsName: String, deps: List[scala.meta.Term.Name], encoder: Option[scala.meta.Defn.Val], @@ -485,7 +1292,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) } - override def extractArrayType(arr: core.ResolvedType[ScalaLanguage], concreteTypes: List[PropMeta[ScalaLanguage]]) = + private def extractArrayType(arr: core.ResolvedType[ScalaLanguage], concreteTypes: List[PropMeta[ScalaLanguage]]) = for { result <- arr match { case core.Resolved(tpe, dep, default, _) => Target.pure(tpe) @@ -506,10 +1313,10 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa } } yield result - override def extractConcreteTypes(definitions: Either[String, List[PropMeta[ScalaLanguage]]]) = + private def extractConcreteTypes(definitions: Either[String, List[PropMeta[ScalaLanguage]]]) = definitions.fold[Target[List[PropMeta[ScalaLanguage]]]](Target.raiseUserError _, Target.pure _) - override def protocolImports() = + private def protocolImports() = Target.pure( List( q"import cats.syntax.either._", @@ -528,7 +1335,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) } - override def packageObjectImports() = + private def packageObjectImports() = Target.pure(List.empty) override def generateSupportDefinitions() = { @@ -567,7 +1374,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa Target.pure(List(presenceDefinition)) } - override def packageObjectContents() = + private def packageObjectContents() = Target.pure( List( q"implicit val guardrailDecodeInstant: _root_.io.circe.Decoder[java.time.Instant] = _root_.io.circe.Decoder[java.time.Instant].or(_root_.io.circe.Decoder[_root_.scala.Long].map(java.time.Instant.ofEpochMilli))", @@ -587,9 +1394,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) ) - override def implicitsObject() = Target.pure(None) + private def implicitsObject() = Target.pure(None) - override def extractSuperClass( + private def extractSuperClass( swagger: Tracker[ComposedSchema], definitions: List[(String, Tracker[Schema[_]])] ) = { @@ -612,7 +1419,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa allParents(swagger) } - override def renderADTStaticDefns( + private def renderADTStaticDefns( clsName: String, discriminator: Discriminator[ScalaLanguage], encoder: Option[scala.meta.Defn.Val], @@ -630,7 +1437,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) ) - override def decodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { + private def decodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { val (childrenCases, childrenDiscriminators) = children.map { child => val discriminatorValue = discriminator.mapping .collectFirst { case (value, elem) if elem.name == child => value } @@ -653,7 +1460,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa Target.pure(Some(code)) } - override def encodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { + private def encodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { val childrenCases = children.map { child => val discriminatorValue = discriminator.mapping .collectFirst { case (value, elem) if elem.name == child => value } @@ -667,7 +1474,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa Target.pure(Some(code)) } - override def renderSealedTrait( + private def renderSealedTrait( className: String, params: List[ProtocolParameter[ScalaLanguage]], discriminator: Discriminator[ScalaLanguage], diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala index ce0e8c4f7a..54e3182b04 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala @@ -1,17 +1,44 @@ package dev.guardrail.generators.scala.circe import _root_.io.swagger.v3.oas.models.media.Schema -import dev.guardrail.Target -import dev.guardrail.core.Tracker -import dev.guardrail.generators.scala.{ CirceRefinedModelGenerator, ScalaLanguage } -import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } -import dev.guardrail.terms.ProtocolTerms -import dev.guardrail.terms.protocol._ +import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, _ } +import _root_.io.swagger.v3.oas.models.{ Components, OpenAPI } +import cats.Foldable +import cats.Monad +import cats.data.{ NonEmptyList, NonEmptyVector } import cats.implicits._ - +import scala.jdk.CollectionConverters._ import scala.meta._ import scala.reflect.runtime.universe.typeTag +import dev.guardrail.core +import dev.guardrail.core.extract.Default +import dev.guardrail.core.extract.{ DataRedaction, EmptyValueIsNull } +import dev.guardrail.core.{ DataRedacted, DataVisible, EmptyIsEmpty, EmptyIsNull, LiteralRawType, Mappish, ReifiedRawType, SupportDefinition, Tracker } +import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEnumSchema, wrapObjectEnumSchema, wrapStringEnumSchema } +import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent } +import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations +import dev.guardrail.generators.scala.{ CirceModelGenerator, CirceRefinedModelGenerator, ScalaGenerator, ScalaLanguage } +import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } +import dev.guardrail.terms.framework.FrameworkTerms +import dev.guardrail.terms.protocol._ +import dev.guardrail.terms.{ + CollectionsLibTerms, + HeldEnum, + IntHeldEnum, + LanguageTerms, + LongHeldEnum, + ProtocolTerms, + RenderedEnum, + RenderedIntEnum, + RenderedLongEnum, + RenderedStringEnum, + StringHeldEnum, + SwaggerTerms +} +import dev.guardrail.{ SwaggerUtil, Target, UserError } + class CirceRefinedProtocolGeneratorLoader extends ProtocolGeneratorLoader { type L = ScalaLanguage def reified = typeTag[Target[ScalaLanguage]] @@ -31,7 +58,7 @@ object StandardContainers { object CirceRefinedProtocolGenerator { def apply(circeRefinedVersion: CirceRefinedModelGenerator): ProtocolTerms[ScalaLanguage, Target] = - fromGenerator(CirceProtocolGenerator.withValidations(circeRefinedVersion.toCirce, applyValidations)) + new CirceRefinedProtocolGenerator(circeRefinedVersion.toCirce, applyValidations) def applyValidations(className: String, tpe: Type, prop: Tracker[Schema[_]]): Target[Type] = { import scala.meta._ @@ -101,10 +128,1221 @@ object CirceRefinedProtocolGenerator { case _ => Target.pure(tpe) } } +} + +class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, applyValidations: WithValidations) + extends ProtocolTerms[ScalaLanguage, Target] { + + override implicit def MonadF: Monad[Target] = Target.targetInstances + + import Target.targetInstances // TODO: Remove me. This resolves implicit ambiguity from MonadChain + + override def fromSwagger( + swagger: Tracker[OpenAPI], + dtoPackage: List[String], + supportPackage: NonEmptyList[String], + defaultPropertyRequirement: PropertyRequirement + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolDefinitions[ScalaLanguage]] = { + import Sc._ + + val components = swagger.downField("components", _.getComponents()) + val definitions = components.flatDownField("schemas", _.getSchemas()).indexedCosequence + Sw.log.function("ProtocolGenerator.fromSwagger")(for { + (hierarchies, definitionsWithoutPoly) <- groupHierarchies(definitions) + + concreteTypes <- SwaggerUtil.extractConcreteTypes[ScalaLanguage, Target](definitions.value, components) + polyADTs <- hierarchies.traverse(fromPoly(_, concreteTypes, definitions.value, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components)) + elems <- definitionsWithoutPoly.traverse { case (clsName, model) => + model + .refine { case c: ComposedSchema => c }(comp => + for { + formattedClsName <- formatTypeName(clsName) + parents <- extractParents(comp, definitions.value, concreteTypes, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components) + model <- fromModel( + clsName = NonEmptyList.of(formattedClsName), + model = comp, + parents = parents, + concreteTypes = concreteTypes, + definitions = definitions.value, + dtoPackage = dtoPackage, + supportPackage = supportPackage.toList, + defaultPropertyRequirement = defaultPropertyRequirement, + components = components + ) + alias <- modelTypeAlias(formattedClsName, comp, components) + } yield model.getOrElse(alias) + ) + .orRefine { case a: ArraySchema => a }(arr => + for { + formattedClsName <- formatTypeName(clsName) + array <- fromArray(formattedClsName, arr, concreteTypes, components) + } yield array + ) + .orRefine { case o: ObjectSchema => o }(m => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum[Object](formattedClsName, m, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + m, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + alias <- modelTypeAlias(formattedClsName, m, components) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: StringSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .orRefine { case x: IntegerSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .valueOr(x => + for { + formattedClsName <- formatTypeName(clsName) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + res <- typeAlias(formattedClsName, declType) + } yield res + ) + } + protoImports <- protocolImports() + pkgImports <- packageObjectImports() + pkgObjectContents <- packageObjectContents() + implicitsObject <- implicitsObject() + + polyADTElems <- ProtocolElems.resolve[ScalaLanguage, Target](polyADTs) + strictElems <- ProtocolElems.resolve[ScalaLanguage, Target](elems) + } yield ProtocolDefinitions[ScalaLanguage](strictElems ++ polyADTElems, protoImports, pkgImports, pkgObjectContents, implicitsObject)) + } + + private[this] def getRequiredFieldsRec(root: Tracker[Schema[_]]): List[String] = { + @scala.annotation.tailrec + def work(values: List[Tracker[Schema[_]]], acc: List[String]): List[String] = { + val required: List[String] = values.flatMap(_.downField("required", _.getRequired()).unwrapTracker) + val next: List[Tracker[Schema[_]]] = + for { + a <- values + b <- a.refine { case x: ComposedSchema => x }(_.downField("allOf", _.getAllOf())).toOption.toList + c <- b.indexedDistribute + } yield c + + val newRequired = acc ++ required + + next match { + case next @ (_ :: _) => work(next, newRequired) + case Nil => newRequired + } + } + work(List(root), Nil) + } + + private[this] def fromEnum[A]( + clsName: String, + schema: Tracker[Schema[A]], + dtoPackage: List[String], + components: Tracker[Option[Components]] + )(implicit + P: ProtocolTerms[ScalaLanguage, Target], + F: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target], + wrapEnumSchema: WrapEnumSchema[A] + ): Target[Either[String, EnumDefinition[ScalaLanguage]]] = { + import Sc._ + import Sw._ + + def validProg(held: HeldEnum, tpe: scala.meta.Type, fullType: scala.meta.Type): Target[EnumDefinition[ScalaLanguage]] = + for { + (pascalValues, wrappedValues) <- held match { + case StringHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(elem) + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedStringEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case IntHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedIntEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case LongHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedLongEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + } + members <- renderMembers(clsName, wrappedValues) + encoder <- encodeEnum(clsName, tpe) + decoder <- decodeEnum(clsName, tpe) + + defn <- renderClass(clsName, tpe, wrappedValues) + staticDefns <- renderStaticDefns(clsName, tpe, members, pascalValues, encoder, decoder) + classType <- pureTypeName(clsName) + } yield EnumDefinition[ScalaLanguage](clsName, classType, fullType, wrappedValues, defn, staticDefns) + + for { + enum <- extractEnum(schema.map(wrapEnumSchema)) + customTpeName <- SwaggerUtil.customTypeName(schema) + (tpe, _) <- SwaggerUtil.determineTypeName(schema, Tracker.cloneHistory(schema, customTpeName), components) + fullType <- selectType(NonEmptyList.ofInitLast(dtoPackage, clsName)) + res <- enum.traverse(validProg(_, tpe, fullType)) + } yield res + } + + private[this] def getPropertyRequirement( + schema: Tracker[Schema[_]], + isRequired: Boolean, + defaultPropertyRequirement: PropertyRequirement + ): PropertyRequirement = + (for { + isNullable <- schema.downField("nullable", _.getNullable) + } yield (isRequired, isNullable) match { + case (true, None) => PropertyRequirement.Required + case (true, Some(false)) => PropertyRequirement.Required + case (true, Some(true)) => PropertyRequirement.RequiredNullable + case (false, None) => defaultPropertyRequirement + case (false, Some(false)) => PropertyRequirement.Optional + case (false, Some(true)) => PropertyRequirement.OptionalNullable + }).unwrapTracker - def renderDTOStaticDefns( - base: ProtocolTerms[ScalaLanguage, Target] + /** Handle polymorphic model + */ + private[this] def fromPoly( + hierarchy: ClassParent[ScalaLanguage], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Sc._ + + def child(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + hierarchy.children.map(_.name) ::: hierarchy.children.flatMap(child) + def parent(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + if (hierarchy.children.nonEmpty) hierarchy.name :: hierarchy.children.flatMap(parent) + else Nil + + val children = child(hierarchy).diff(parent(hierarchy)).distinct + val discriminator = hierarchy.discriminator + + for { + parents <- hierarchy.model + .refine[Target[List[SuperClass[ScalaLanguage]]]] { case c: ComposedSchema => c }( + extractParents(_, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + ) + .getOrElse(List.empty[SuperClass[ScalaLanguage]].pure[Target]) + props <- extractProperties(hierarchy.model) + requiredFields = hierarchy.required ::: hierarchy.children.flatMap(_.required) + params <- props.traverse { case (name, prop) => + for { + typeName <- formatTypeName(name).map(formattedName => NonEmptyList.of(hierarchy.name, formattedName)) + propertyRequirement = getPropertyRequirement(prop, requiredFields.contains(name), defaultPropertyRequirement) + customType <- SwaggerUtil.customTypeName(prop) + resolvedType <- SwaggerUtil + .propMeta[ScalaLanguage, Target]( + prop, + components + ) // TODO: This should be resolved via an alternate mechanism that maintains references all the way through, instead of re-deriving and assuming that references are valid + defValue <- defaultValue(typeName, prop, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + res <- transformProperty(hierarchy.name, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + prop, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield res + } + definition <- renderSealedTrait(hierarchy.name, params, discriminator, parents, children) + encoder <- encodeADT(hierarchy.name, hierarchy.discriminator, children) + decoder <- decodeADT(hierarchy.name, hierarchy.discriminator, children) + staticDefns <- renderADTStaticDefns(hierarchy.name, discriminator, encoder, decoder) + tpe <- pureTypeName(hierarchy.name) + fullType <- selectType(NonEmptyList.fromList(dtoPackage :+ hierarchy.name).getOrElse(NonEmptyList.of(hierarchy.name))) + } yield ADT[ScalaLanguage]( + name = hierarchy.name, + tpe = tpe, + fullType = fullType, + trt = definition, + staticDefns = staticDefns + ) + } + + private def extractParents( + elem: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])], + concreteTypes: List[PropMeta[ScalaLanguage]], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[List[SuperClass[ScalaLanguage]]] = { + import Sc._ + + for { + a <- extractSuperClass(elem, definitions) + supper <- a.flatTraverse { case (clsName, _extends, interfaces) => + val concreteInterfacesWithClass = for { + interface <- interfaces + (cls, tracker) <- definitions + result <- tracker + .refine[Tracker[Schema[_]]] { + case x: ComposedSchema if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x + }( + identity _ + ) + .orRefine { case x: Schema[_] if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x }(identity _) + .toOption + } yield cls -> result + val (_, concreteInterfaces) = concreteInterfacesWithClass.unzip + val classMapping = (for { + (cls, schema) <- concreteInterfacesWithClass + (name, _) <- schema.downField("properties", _.getProperties).indexedDistribute.value + } yield (name, cls)).toMap + for { + _extendsProps <- extractProperties(_extends) + requiredFields = getRequiredFieldsRec(_extends) ++ concreteInterfaces.flatMap(getRequiredFieldsRec) + _withProps <- concreteInterfaces.traverse(extractProperties) + props = _extendsProps ++ _withProps.flatten + (params, _) <- prepareProperties( + NonEmptyList.of(clsName), + classMapping, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + interfacesCls = interfaces.flatMap(_.downField("$ref", _.get$ref).unwrapTracker.map(_.split("/").last)) + tpe <- parseTypeName(clsName) + + discriminators <- (_extends :: concreteInterfaces).flatTraverse( + _.refine[Target[List[Discriminator[ScalaLanguage]]]] { case m: ObjectSchema => m }(m => Discriminator.fromSchema(m).map(_.toList)) + .getOrElse(List.empty[Discriminator[ScalaLanguage]].pure[Target]) + ) + } yield tpe + .map( + SuperClass[ScalaLanguage]( + clsName, + _, + interfacesCls, + params, + discriminators + ) + ) + .toList + } + + } yield supper + } + + private[this] def fromModel( + clsName: NonEmptyList[String], + model: Tracker[Schema[_]], + parents: List[SuperClass[ScalaLanguage]], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Either[String, ClassDefinition[ScalaLanguage]]] = { + import Sc._ + + for { + props <- extractProperties(model) + requiredFields = getRequiredFieldsRec(model) + (params, nestedDefinitions) <- prepareProperties( + clsName, + Map.empty, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + encoder <- encodeModel(clsName.last, dtoPackage, params, parents) + decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents) + tpe <- parseTypeName(clsName.last) + fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x))) + staticDefns <- renderDTOStaticDefns(clsName.last, List.empty, encoder, decoder, params) + nestedClasses <- nestedDefinitions.flatTraverse { + case classDefinition: ClassDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(classDefinition.cls) + companionTerm <- pureTermName(classDefinition.name) + companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_)) + case enumDefinition: EnumDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(enumDefinition.cls) + companionTerm <- pureTermName(enumDefinition.name) + companionDefinition <- wrapToObject(companionTerm, enumDefinition.staticDefns.extraImports, enumDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(enumDefinition.staticDefns.definitions)(List(_)) + } + defn <- renderDTOClass(clsName.last, supportPackage, params, parents) + } yield { + val finalStaticDefns = staticDefns.copy(definitions = staticDefns.definitions ++ nestedClasses) + if (parents.isEmpty && props.isEmpty) Left("Entity isn't model"): Either[String, ClassDefinition[ScalaLanguage]] + else tpe.toRight("Empty entity name").map(ClassDefinition[ScalaLanguage](clsName.last, _, fullType, defn, finalStaticDefns, parents)) + } + + } + + private def prepareProperties( + clsName: NonEmptyList[String], + propertyToTypeLookup: Map[String, String], + props: List[(String, Tracker[Schema[_]])], + requiredFields: List[String], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ProtocolParameter[ScalaLanguage]], List[NestedProtocolElems[ScalaLanguage]])] = { + import Sc._ + def getClsName(name: String): NonEmptyList[String] = propertyToTypeLookup.get(name).map(NonEmptyList.of(_)).getOrElse(clsName) + + def processProperty(name: String, schema: Tracker[Schema[_]]): Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]] = + for { + nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + defn <- schema + .refine[Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]]] { case x: ObjectSchema => x }(o => + for { + defn <- fromModel( + nestedClassName, + o, + List.empty, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(defn) + ) + .orRefine { case o: ComposedSchema => o }(o => + for { + parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + maybeClassDefinition <- fromModel( + nestedClassName, + o, + parents, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(maybeClassDefinition) + ) + .orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _))) + .orRefine { case s: StringSchema if Option(s.getEnum).map(_.asScala).exists(_.nonEmpty) => s }(s => + fromEnum(nestedClassName.last, s, dtoPackage, components).map(Option(_)) + ) + .getOrElse(Option.empty[Either[String, NestedProtocolElems[ScalaLanguage]]].pure[Target]) + } yield defn + + for { + paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[ScalaLanguage]], Option[NestedProtocolElems[ScalaLanguage]])] { + case (name, schema) => + for { + typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + tpe <- selectType(typeName) + maybeNestedDefinition <- processProperty(name, schema) + resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, components) + customType <- SwaggerUtil.customTypeName(schema) + propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement) + defValue <- defaultValue(typeName, schema, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + parameter <- transformProperty(getClsName(name).last, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + schema, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield (Tracker.cloneHistory(schema, parameter), maybeNestedDefinition.flatMap(_.toOption)) + } + (params, nestedDefinitions) = paramsAndNestedDefinitions.unzip + deduplicatedParams <- deduplicateParams(params) + unconflictedParams <- fixConflictingNames(deduplicatedParams) + } yield (unconflictedParams, nestedDefinitions.flatten) + } + + private def deduplicateParams( + params: List[Tracker[ProtocolParameter[ScalaLanguage]]] + )(implicit Sw: SwaggerTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Sc._ + Foldable[List] + .foldLeftM[Target, Tracker[ProtocolParameter[ScalaLanguage]], List[ProtocolParameter[ScalaLanguage]]]( + params, + List.empty[ProtocolParameter[ScalaLanguage]] + ) { (s, ta) => + val a = ta.unwrapTracker + s.find(p => p.name == a.name) match { + case None => (a :: s).pure[Target] + case Some(duplicate) => + for { + newDefaultValue <- findCommonDefaultValue(ta.showHistory, a.defaultValue, duplicate.defaultValue) + newRawType <- findCommonRawType(ta.showHistory, a.rawType, duplicate.rawType) + } yield { + val emptyToNull = if (Set(a.emptyToNull, duplicate.emptyToNull).contains(EmptyIsNull)) EmptyIsNull else EmptyIsEmpty + val redactionBehaviour = if (Set(a.dataRedaction, duplicate.dataRedaction).contains(DataRedacted)) DataRedacted else DataVisible + val mergedParameter = ProtocolParameter[ScalaLanguage]( + a.term, + a.baseType, + a.name, + a.dep, + newRawType, + a.readOnlyKey.orElse(duplicate.readOnlyKey), + emptyToNull, + redactionBehaviour, + a.propertyRequirement, + newDefaultValue, + a.propertyValidation + ) + mergedParameter :: s.filter(_.name != a.name) + } + } + } + .map(_.reverse) + } + + private def fixConflictingNames( + params: List[ProtocolParameter[ScalaLanguage]] + )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Lt._ + for { + paramsWithNames <- params.traverse(param => extractTermNameFromParam(param.term).map((_, param))) + counts = paramsWithNames.groupBy(_._1).view.mapValues(_.length).toMap + newParams <- paramsWithNames.traverse { case (name, param) => + if (counts.getOrElse(name, 0) > 1) { + for { + newTermName <- pureTermName(param.name.value) + newMethodParam <- alterMethodParameterName(param.term, newTermName) + } yield ProtocolParameter( + newMethodParam, + param.baseType, + param.name, + param.dep, + param.rawType, + param.readOnlyKey, + param.emptyToNull, + param.dataRedaction, + param.propertyRequirement, + param.defaultValue, + param.propertyValidation + ) + } else { + param.pure[Target] + } + } + } yield newParams + } + + private def modelTypeAlias(clsName: String, abstractModel: Tracker[Schema[_]], components: Tracker[Option[Components]])(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + val model: Option[Tracker[ObjectSchema]] = abstractModel + .refine[Option[Tracker[ObjectSchema]]] { case m: ObjectSchema => m }(x => Option(x)) + .orRefine { case m: ComposedSchema => m }( + _.downField("allOf", _.getAllOf()).indexedCosequence + .get(1) + .flatMap( + _.refine { case o: ObjectSchema => o }(Option.apply) + .orRefineFallback(_ => None) + ) + ) + .orRefineFallback(_ => None) + for { + tpe <- model.fold[Target[scala.meta.Type]](objectType(None)) { m => + for { + tpeName <- SwaggerUtil.customTypeName[ScalaLanguage, Target, Tracker[ObjectSchema]](m) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](m, Tracker.cloneHistory(m, tpeName), components) + } yield declType + } + res <- typeAlias(clsName, tpe) + } yield res + } + + private def plainTypeAlias( + clsName: String + )(implicit Fw: FrameworkTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + for { + tpe <- objectType(None) + res <- typeAlias(clsName, tpe) + } yield res + } + + private def typeAlias(clsName: String, tpe: scala.meta.Type): Target[ProtocolElems[ScalaLanguage]] = + (RandomType[ScalaLanguage](clsName, tpe): ProtocolElems[ScalaLanguage]).pure[Target] + + private def fromArray(clsName: String, arr: Tracker[ArraySchema], concreteTypes: List[PropMeta[ScalaLanguage]], components: Tracker[Option[Components]])( + implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = + for { + deferredTpe <- SwaggerUtil.modelMetaType(arr, components) + tpe <- extractArrayType(deferredTpe, concreteTypes) + ret <- typeAlias(clsName, tpe) + } yield ret + + /** returns objects grouped into hierarchies + */ + private def groupHierarchies( + definitions: Mappish[List, String, Tracker[Schema[_]]] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ClassParent[ScalaLanguage]], List[(String, Tracker[Schema[_]])])] = { + + def firstInHierarchy(model: Tracker[Schema[_]]): Option[Tracker[ObjectSchema]] = + model + .refine { case x: ComposedSchema => x } { elem => + definitions.value + .collectFirst { + case (clsName, element) + if elem.downField("allOf", _.getAllOf).exists(_.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName"))) => + element + } + .flatMap( + _.refine { case x: ComposedSchema => x }(firstInHierarchy) + .orRefine { case o: ObjectSchema => o }(x => Option(x)) + .getOrElse(None) + ) + } + .getOrElse(None) + + def children(cls: String): List[ClassChild[ScalaLanguage]] = definitions.value.flatMap { case (clsName, comp) => + comp + .refine { case x: ComposedSchema => x }(comp => + if ( + comp + .downField("allOf", _.getAllOf()) + .exists(x => x.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$cls"))) + ) { + Some(ClassChild(clsName, comp, children(clsName), getRequiredFieldsRec(comp))) + } else None + ) + .getOrElse(None) + } + + def classHierarchy(cls: String, model: Tracker[Schema[_]]): Target[Option[ClassParent[ScalaLanguage]]] = + model + .refine { case c: ComposedSchema => c }(c => + firstInHierarchy(c) + .fold(Option.empty[Discriminator[ScalaLanguage]].pure[Target])(Discriminator.fromSchema[ScalaLanguage, Target]) + .map(_.map((_, getRequiredFieldsRec(c)))) + ) + .orRefine { case x: Schema[_] => x }(m => Discriminator.fromSchema(m).map(_.map((_, getRequiredFieldsRec(m))))) + .getOrElse(Option.empty[(Discriminator[ScalaLanguage], List[String])].pure[Target]) + .map(_.map { case (discriminator, reqFields) => ClassParent(cls, model, children(cls), discriminator, reqFields) }) + + Sw.log.function("groupHierarchies")( + definitions.value + .traverse { case (cls, model) => + for { + hierarchy <- classHierarchy(cls, model) + } yield hierarchy.filterNot(_.children.isEmpty).toLeft((cls, model)) + } + .map(_.partitionEither[ClassParent[ScalaLanguage], (String, Tracker[Schema[_]])](identity)) + ) + } + + private def defaultValue( + name: NonEmptyList[String], + schema: Tracker[Schema[_]], + requirement: PropertyRequirement, + definitions: List[(String, Tracker[Schema[_]])] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target] + ): Target[Option[scala.meta.Term]] = { + import Sc._ + import Cl._ + val empty = Option.empty[scala.meta.Term].pure[Target] + schema.downField("$ref", _.get$ref()).indexedDistribute match { + case Some(ref) => + definitions + .collectFirst { + case (cls, refSchema) if ref.unwrapTracker.endsWith(s"/$cls") => + defaultValue(NonEmptyList.of(cls), refSchema, requirement, definitions) + } + .getOrElse(empty) + case None => + schema + .refine { case map: MapSchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => map }(map => + for { + customTpe <- SwaggerUtil.customMapTypeName(map) + result <- customTpe.fold(emptyMap().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case arr: ArraySchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => arr }( + arr => + for { + customTpe <- SwaggerUtil.customArrayTypeName(arr) + result <- customTpe.fold(emptyArray().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case p: BooleanSchema => p }(p => Default(p).extract[Boolean].fold(empty)(litBoolean(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "double" => p }(p => Default(p).extract[Double].fold(empty)(litDouble(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "float" => p }(p => Default(p).extract[Float].fold(empty)(litFloat(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int32" => p }(p => Default(p).extract[Int].fold(empty)(litInt(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int64" => p }(p => Default(p).extract[Long].fold(empty)(litLong(_).map(Some(_)))) + .orRefine { case p: StringSchema if Option(p.getEnum).map(_.asScala).exists(_.nonEmpty) => p }(p => + Default(p).extract[String] match { + case Some(defaultEnumValue) => + for { + enumName <- formatEnumName(defaultEnumValue) + result <- selectTerm(name.append(enumName)) + } yield Some(result) + case None => empty + } + ) + .orRefine { case p: StringSchema => p }(p => Default(p).extract[String].fold(empty)(litString(_).map(Some(_)))) + .getOrElse(empty) + } + } + + private def suffixClsName(prefix: String, clsName: String): Pat.Var = Pat.Var(Term.Name(s"${prefix}${clsName}")) + + private def lookupTypeName(tpeName: String, concreteTypes: List[PropMeta[ScalaLanguage]])(f: Type => Type): Option[Type] = + concreteTypes + .find(_.clsName == tpeName) + .map(_.tpe) + .map(f) + + private def renderMembers(clsName: String, elems: RenderedEnum[ScalaLanguage]) = { + val fields = elems match { + case RenderedStringEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.String(value)) + } + case RenderedIntEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.Int(value)) + } + case RenderedLongEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.Long(value)) + } + } + + Target.pure(Some(q""" + object members { + ..${fields.map { case (termName, lit) => q"""case object ${termName} extends ${Type.Name(clsName)}(${lit})""" }} + } + """)) + } + + private def encodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = + Target.pure(Some(q""" + implicit val ${suffixClsName("encode", clsName)}: _root_.io.circe.Encoder[${Type.Name(clsName)}] = + _root_.io.circe.Encoder[${tpe}].contramap(_.value) + """)) + + private def decodeEnum(clsName: String, tpe: Type): Target[Option[Defn]] = + Target.pure(Some(q""" + implicit val ${suffixClsName("decode", clsName)}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = + _root_.io.circe.Decoder[${tpe}].emap(value => from(value).toRight(${Term + .Interpolate(Term.Name("s"), List(Lit.String(""), Lit.String(s" not a member of ${clsName}")), List(Term.Name("value")))})) + """)) + + private def renderClass(clsName: String, tpe: scala.meta.Type, elems: RenderedEnum[ScalaLanguage]) = + Target.pure(q""" + sealed abstract class ${Type.Name(clsName)}(val value: ${tpe}) extends _root_.scala.Product with _root_.scala.Serializable { + override def toString: String = value.toString + } + """) + + private def renderStaticDefns( + clsName: String, + tpe: scala.meta.Type, + members: Option[scala.meta.Defn.Object], + accessors: List[scala.meta.Term.Name], + encoder: Option[scala.meta.Defn], + decoder: Option[scala.meta.Defn] + ): Target[StaticDefns[ScalaLanguage]] = { + val longType = Type.Name(clsName) + val terms: List[Defn.Val] = accessors.map { pascalValue => + q"val ${Pat.Var(pascalValue)}: ${longType} = members.${pascalValue}" + }.toList + val values: Defn.Val = q"val values = _root_.scala.Vector(..$accessors)" + val implicits: List[Defn.Val] = List( + q"implicit val ${Pat.Var(Term.Name(s"show${clsName}"))}: Show[${longType}] = Show[${tpe}].contramap[${longType}](_.value)" + ) + Target.pure( + StaticDefns[ScalaLanguage]( + className = clsName, + extraImports = List.empty[Import], + definitions = members.toList ++ + terms ++ + List(Some(values), encoder, decoder).flatten ++ + implicits ++ + List( + q"def from(value: ${tpe}): _root_.scala.Option[${longType}] = values.find(_.value == value)", + q"implicit val order: cats.Order[${longType}] = cats.Order.by[${longType}, Int](values.indexOf)" + ) + ) + ) + } + + override def buildAccessor(clsName: String, termName: String) = + Target.pure(q"${Term.Name(clsName)}.${Term.Name(termName)}") + + private def extractProperties(swagger: Tracker[Schema[_]]) = + swagger + .refine[Target[List[(String, Tracker[Schema[_]])]]] { case o: ObjectSchema => o }(m => + Target.pure(m.downField("properties", _.getProperties()).indexedCosequence.value) + ) + .orRefine { case c: ComposedSchema => c } { comp => + val extractedProps = + comp + .downField("allOf", _.getAllOf()) + .indexedDistribute + .flatMap(_.downField("properties", _.getProperties).indexedCosequence.value) + Target.pure(extractedProps) + } + .orRefine { case x: Schema[_] if Option(x.get$ref()).isDefined => x }(comp => + Target.raiseUserError(s"Attempted to extractProperties for a ${comp.unwrapTracker.getClass()}, unsure what to do here (${comp.showHistory})") + ) + .getOrElse(Target.pure(List.empty[(String, Tracker[Schema[_]])])) + + private def transformProperty( + clsName: String, + dtoPackage: List[String], + supportPackage: List[String], + concreteTypes: List[PropMeta[ScalaLanguage]] )( + name: String, + fieldName: String, + property: Tracker[Schema[_]], + meta: core.ResolvedType[ScalaLanguage], + requirement: PropertyRequirement, + isCustomType: Boolean, + defaultValue: Option[scala.meta.Term] + ): Target[ProtocolParameter[ScalaLanguage]] = + Target.log.function(s"transformProperty") { + val fallbackRawType = ReifiedRawType.of(property.downField("type", _.getType()).unwrapTracker, property.downField("format", _.getFormat()).unwrapTracker) + for { + _ <- Target.log.debug(s"Args: (${clsName}, ${name}, ...)") + + readOnlyKey = Option(name).filter(_ => property.downField("readOnly", _.getReadOnly()).unwrapTracker.contains(true)) + emptyToNull = property + .refine { case d: DateSchema => d }(d => EmptyValueIsNull(d)) + .orRefine { case dt: DateTimeSchema => dt }(dt => EmptyValueIsNull(dt)) + .orRefine { case s: StringSchema => s }(s => EmptyValueIsNull(s)) + .toOption + .flatten + .getOrElse(EmptyIsEmpty) + + dataRedaction = DataRedaction(property).getOrElse(DataVisible) + + (tpe, classDep, rawType) <- meta match { + case core.Resolved(declType, classDep, _, rawType @ LiteralRawType(Some(rawTypeStr), rawFormat)) + if SwaggerUtil.isFile(rawTypeStr, rawFormat) && !isCustomType => + // assume that binary data are represented as a string. allow users to override. + Target.pure((t"String", classDep, rawType)) + case core.Resolved(declType, classDep, _, rawType) => + for { + validatedType <- applyValidations(clsName, declType, property) + } yield (validatedType, classDep, rawType) + case core.Deferred(tpeName) => + val tpe = concreteTypes.find(_.clsName == tpeName).map(_.tpe).getOrElse { + println(s"Unable to find definition for ${tpeName}, just inlining") + Type.Name(tpeName) + } + for { + validatedType <- applyValidations(clsName, tpe, property) + } yield (validatedType, Option.empty, fallbackRawType) + case core.DeferredArray(tpeName, containerTpe) => + val concreteType = lookupTypeName(tpeName, concreteTypes)(identity) + val innerType = concreteType.getOrElse(Type.Name(tpeName)) + val tpe = t"${containerTpe.getOrElse(t"_root_.scala.Vector")}[$innerType]" + for { + validatedType <- applyValidations(clsName, tpe, property) + } yield (validatedType, Option.empty, ReifiedRawType.ofVector(fallbackRawType)) + case core.DeferredMap(tpeName, customTpe) => + val concreteType = lookupTypeName(tpeName, concreteTypes)(identity) + val innerType = concreteType.getOrElse(Type.Name(tpeName)) + val tpe = t"${customTpe.getOrElse(t"_root_.scala.Predef.Map")}[_root_.scala.Predef.String, $innerType]" + for { + validatedType <- applyValidations(clsName, tpe, property) + } yield (validatedType, Option.empty, ReifiedRawType.ofMap(fallbackRawType)) + } + fieldPattern: Tracker[Option[String]] = property.downField("pattern", _.getPattern) + collectionElementPattern: Option[Tracker[String]] = + property.downField("items", _.getItems).indexedDistribute.flatMap(_.downField("pattern", _.getPattern).indexedDistribute) + + pattern = collectionElementPattern.fold(fieldPattern.map(PropertyValidations))( + _.map(regex => PropertyValidations(Some(regex))) + ) + + presence <- ScalaGenerator().selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence")) + presenceType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage, "Presence")) + (finalDeclType, finalDefaultValue) = requirement match { + case PropertyRequirement.Required => tpe -> defaultValue + case PropertyRequirement.Optional | PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + t"$presenceType[$tpe]" -> defaultValue.map(t => q"$presence.Present($t)").orElse(Some(q"$presence.Absent")) + case _: PropertyRequirement.OptionalRequirement | _: PropertyRequirement.Configured => + t"Option[$tpe]" -> defaultValue.map(t => q"Option($t)").orElse(Some(q"None")) + case PropertyRequirement.OptionalNullable => + t"$presenceType[Option[$tpe]]" -> defaultValue.map(t => q"$presence.Present($t)") + } + term = param"${Term.Name(fieldName)}: ${finalDeclType}".copy(default = finalDefaultValue) + dep = classDep.filterNot(_.value == clsName) // Filter out our own class name + } yield ProtocolParameter[ScalaLanguage]( + term, + tpe, + RawParameterName(name), + dep, + rawType, + readOnlyKey, + emptyToNull, + dataRedaction, + requirement, + finalDefaultValue, + pattern + ) + } + + private def renderDTOClass( + clsName: String, + supportPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ) = { + val discriminators = parents.flatMap(_.discriminators) + val discriminatorNames = discriminators.map(_.propertyName).toSet + val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) { + parents.headOption + } else { + None + } + val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + + val terms = params.map(_.term) + + val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) { + def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { + case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" + case _ => Lit.String("[redacted]") + } + + val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) + + List[Defn.Def]( + q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" + ) + } else { + List.empty[Defn.Def] + } + + val code = parentOpt + .fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent => + q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a => + init"${Type.Name(a)}(...$Nil)" + )} { ..$toStringMethod }""" + ) + + Target.pure(code) + } + + private def encodeModel( + clsName: String, + dtoPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ) = { + val discriminators = parents.flatMap(_.discriminators) + val discriminatorNames = discriminators.map(_.propertyName).toSet + val allParams = parents.reverse.flatMap(_.params) ++ selfParams + val (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) + val readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList + val typeName = Type.Name(clsName) + val encVal = { + def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) = + q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))""" + + def encodeRequired(param: ProtocolParameter[ScalaLanguage]) = + q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)""" + + def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = { + val name = Lit.String(param.name.value) + q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))" + } + + val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param => + val name = Lit.String(param.name.value) + param.propertyRequirement match { + case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy => + Right(encodeRequired(param)) + case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) => + Right(encodeRequired(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, _) => + Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""") + } + } + + val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName)) + val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})" + val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) => + q"$acc ++ $field" + } + + q""" + ${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields)) + """ + } + val (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys => + ( + List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"), + encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))" + ) + } + + Target.pure(Option(q""" + implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = { + ..${readOnlyDefn}; + ${readOnlyFilter(encVal)} + } + """)) + } + + private def decodeModel( + clsName: String, + dtoPackage: List[String], + supportPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ): Target[Option[Defn.Val]] = { + val discriminators = parents.flatMap(_.discriminators) + val discriminatorNames = discriminators.map(_.propertyName).toSet + val allParams = parents.reverse.flatMap(_.params) ++ selfParams + val params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) + val needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) + val paramCount = params.length + for { + presence <- ScalaGenerator().selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence")) + decVal <- + if (paramCount == 0) { + Target.pure( + q""" + new _root_.io.circe.Decoder[${Type.Name(clsName)}] { + final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] = + _root_.scala.Right(${Term.Name(clsName)}()) + } + """ + ) + } else { + params.zipWithIndex + .traverse { case (param, idx) => + for { + rawTpe <- Target.fromOption(param.term.decltpe, UserError("Missing type")) + tpe <- rawTpe match { + case tpe: Type => Target.pure(tpe) + case x => Target.raiseUserError(s"Unsure how to map ${x.structure}, please report this bug!") + } + } yield { + val term = Term.Name(s"v$idx") + val name = Lit.String(param.name.value) + + val emptyToNull: Term => Term = if (param.emptyToNull == EmptyIsNull) { t => + q"$t.withFocus(j => j.asString.fold(j)(s => if(s.isEmpty) _root_.io.circe.Json.Null else j))" + } else identity _ + + val decodeField: Type => NonEmptyVector[Term => Term] = { tpe => + NonEmptyVector.of[Term => Term]( + t => q"$t.downField($name)", + emptyToNull, + t => q"$t.as[${tpe}]" + ) + } + + val decodeOptionalField: Type => (Term => Term, Term) => NonEmptyVector[Term => Term] = { tpe => (present, absent) => + NonEmptyVector.of[Term => Term](t => q""" + ((c: _root_.io.circe.HCursor) => + c + .value + .asObject + .filter(!_.contains($name)) + .fold(${emptyToNull(q"c.downField($name)")}.as[${tpe}].map(x => ${present(q"x")})) { _ => + _root_.scala.Right($absent) + } + )($t) + """) + } + + def decodeOptionalRequirement( + param: ProtocolParameter[ScalaLanguage] + ): PropertyRequirement.OptionalRequirement => NonEmptyVector[Term => Term] = { + case PropertyRequirement.OptionalLegacy => + decodeField(tpe) + case PropertyRequirement.RequiredNullable => + decodeField(t"_root_.io.circe.Json") :+ (t => q"$t.flatMap(_.as[${tpe}])") + case PropertyRequirement.Optional => // matched only where there is inconsistency between encoder and decoder + decodeOptionalField(param.baseType)(x => q"Option($x)", q"None") + } + + val parseTermAccessors: NonEmptyVector[Term => Term] = param.propertyRequirement match { + case PropertyRequirement.Required => + decodeField(tpe) + case PropertyRequirement.OptionalNullable => + decodeOptionalField(t"Option[${param.baseType}]")(x => q"$presence.present($x)", q"$presence.absent") + case PropertyRequirement.Optional | PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + decodeOptionalField(param.baseType)(x => q"$presence.present($x)", q"$presence.absent") + case requirement: PropertyRequirement.OptionalRequirement => + decodeOptionalRequirement(param)(requirement) + case PropertyRequirement.Configured(_, decoderRequirement) => + decodeOptionalRequirement(param)(decoderRequirement) + } + + val parseTerm = parseTermAccessors.foldLeft[Term](q"c")((acc, next) => next(acc)) + val _enum = enumerator"""${Pat.Var(term)} <- $parseTerm""" + (term, _enum) + } + } + .map { pairs => + val (terms, enumerators) = pairs.unzip + q""" + new _root_.io.circe.Decoder[${Type.Name(clsName)}] { + final def apply(c: _root_.io.circe.HCursor): _root_.io.circe.Decoder.Result[${Type.Name(clsName)}] = + for { + ..${enumerators} + } yield ${Term.Name(clsName)}(..${terms}) + } + """ + } + } + } yield Option(q""" + implicit val ${suffixClsName("decode", clsName)}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = $decVal + """) + } + + private def renderDTOStaticDefns( clsName: String, deps: List[scala.meta.Term.Name], encoder: Option[scala.meta.Defn.Val], @@ -128,37 +1366,234 @@ object CirceRefinedProtocolGenerator { // to avoid declaring the same type alias multiple times val deduplicatedRegexHelperTypes = regexHelperTypes.groupBy(_.structure).values.flatMap(_.headOption).toList + val extraImports: List[Import] = deps.map { term => + q"import ${term}._" + } + + Target.pure( + StaticDefns[ScalaLanguage]( + className = clsName, + extraImports = extraImports, + definitions = (deduplicatedRegexHelperTypes ++ encoder ++ decoder).toList + ) + ) + } + + private def extractArrayType(arr: core.ResolvedType[ScalaLanguage], concreteTypes: List[PropMeta[ScalaLanguage]]) = for { - defns <- base.renderDTOStaticDefns(clsName, deps, encoder, decoder, protocolParameters) - } yield defns.copy(definitions = deduplicatedRegexHelperTypes ++ defns.definitions) - } - - def fromGenerator(generator: ProtocolTerms[ScalaLanguage, Target]): ProtocolTerms[ScalaLanguage, Target] = - generator.copy( - protocolImports = { () => - generator - .protocolImports() - .map(imports => - imports ++ List( - q"import io.circe.refined._", - q"import eu.timepit.refined.api.Refined", - q"import eu.timepit.refined.auto._" - ) + result <- arr match { + case core.Resolved(tpe, dep, default, _) => Target.pure(tpe) + case core.Deferred(tpeName) => + Target.fromOption(lookupTypeName(tpeName, concreteTypes)(identity), UserError(s"Unresolved reference ${tpeName}")) + case core.DeferredArray(tpeName, containerTpe) => + Target.fromOption( + lookupTypeName(tpeName, concreteTypes)(tpe => t"${containerTpe.getOrElse(t"_root_.scala.Vector")}[${tpe}]"), + UserError(s"Unresolved reference ${tpeName}") + ) + case core.DeferredMap(tpeName, customTpe) => + Target.fromOption( + lookupTypeName(tpeName, concreteTypes)(tpe => + t"_root_.scala.Vector[${customTpe.getOrElse(t"_root_.scala.Predef.Map")}[_root_.scala.Predef.String, ${tpe}]]" + ), + UserError(s"Unresolved reference ${tpeName}") ) - }, - renderDTOStaticDefns = renderDTOStaticDefns(generator) _, - staticProtocolImports = pkgName => { - val implicitsRef: Term.Ref = (pkgName.map(Term.Name.apply _) ++ List(q"Implicits")).foldLeft[Term.Ref](q"_root_")(Term.Select.apply _) - Target.pure( - List( - q"import cats.implicits._", - q"import cats.data.EitherT", - q"import io.circe.refined._", - q"import eu.timepit.refined.api.Refined", - q"import eu.timepit.refined.auto._" - ) :+ q"import $implicitsRef._" - ) } + } yield result + + private def extractConcreteTypes(definitions: Either[String, List[PropMeta[ScalaLanguage]]]) = + definitions.fold[Target[List[PropMeta[ScalaLanguage]]]](Target.raiseUserError _, Target.pure _) + + private def protocolImports() = + Target.pure( + List( + q"import cats.syntax.either._", + q"import io.circe.syntax._", + q"import cats.instances.all._", + q"import io.circe.refined._", + q"import eu.timepit.refined.api.Refined", + q"import eu.timepit.refined.auto._" + ) + ) + + override def staticProtocolImports(pkgName: List[String]): Target[List[Import]] = { + val implicitsRef: Term.Ref = (pkgName.map(Term.Name.apply _) ++ List(q"Implicits")).foldLeft[Term.Ref](q"_root_")(Term.Select.apply _) + Target.pure( + List( + q"import cats.implicits._", + q"import cats.data.EitherT", + q"import io.circe.refined._", + q"import eu.timepit.refined.api.Refined", + q"import eu.timepit.refined.auto._", + q"import $implicitsRef._" + ) + ) + } + + private def packageObjectImports() = + Target.pure(List.empty) + + override def generateSupportDefinitions() = { + val presenceTrait = + q"""sealed trait Presence[+T] extends _root_.scala.Product with _root_.scala.Serializable { + def fold[R](ifAbsent: => R, + ifPresent: T => R): R + def map[R](f: T => R): Presence[R] = fold(Presence.absent, a => Presence.present(f(a))) + + def toOption: Option[T] = fold[Option[T]](None, Some(_)) + } + """ + val presenceObject = + q""" + object Presence { + def absent[R]: Presence[R] = Absent + def present[R](value: R): Presence[R] = Present(value) + case object Absent extends Presence[Nothing] { + def fold[R](ifAbsent: => R, + ifValue: Nothing => R): R = ifAbsent + } + final case class Present[+T](value: T) extends Presence[T] { + def fold[R](ifAbsent: => R, + ifPresent: T => R): R = ifPresent(value) + } + + def fromOption[T](value: Option[T]): Presence[T] = + value.fold[Presence[T]](Absent)(Present(_)) + + implicit object PresenceFunctor extends cats.Functor[Presence] { + def map[A, B](fa: Presence[A])(f: A => B): Presence[B] = fa.fold[Presence[B]](Presence.absent, a => Presence.present(f(a))) + } + } + """ + val presenceDefinition = SupportDefinition[ScalaLanguage](q"Presence", Nil, List(presenceTrait, presenceObject), insideDefinitions = false) + Target.pure(List(presenceDefinition)) + } + + private def packageObjectContents() = + Target.pure( + List( + q"implicit val guardrailDecodeInstant: _root_.io.circe.Decoder[java.time.Instant] = _root_.io.circe.Decoder[java.time.Instant].or(_root_.io.circe.Decoder[_root_.scala.Long].map(java.time.Instant.ofEpochMilli))", + q"implicit val guardrailDecodeLocalDate: _root_.io.circe.Decoder[java.time.LocalDate] = _root_.io.circe.Decoder[java.time.LocalDate].or(_root_.io.circe.Decoder[java.time.Instant].map(_.atZone(java.time.ZoneOffset.UTC).toLocalDate))", + q"implicit val guardrailDecodeLocalDateTime: _root_.io.circe.Decoder[java.time.LocalDateTime] = _root_.io.circe.Decoder[java.time.LocalDateTime]", + q"implicit val guardrailDecodeLocalTime: _root_.io.circe.Decoder[java.time.LocalTime] = _root_.io.circe.Decoder[java.time.LocalTime]", + q"implicit val guardrailDecodeOffsetDateTime: _root_.io.circe.Decoder[java.time.OffsetDateTime] = _root_.io.circe.Decoder[java.time.OffsetDateTime].or(_root_.io.circe.Decoder[java.time.Instant].map(_.atZone(java.time.ZoneOffset.UTC).toOffsetDateTime))", + q"implicit val guardrailDecodeZonedDateTime: _root_.io.circe.Decoder[java.time.ZonedDateTime] = _root_.io.circe.Decoder[java.time.ZonedDateTime]", + q"implicit val guardrailDecodeBase64String: _root_.io.circe.Decoder[Base64String] = _root_.io.circe.Decoder[_root_.scala.Predef.String].emapTry(v => scala.util.Try(java.util.Base64.getDecoder.decode(v))).map(new Base64String(_))", + q"implicit val guardrailEncodeInstant: _root_.io.circe.Encoder[java.time.Instant] = _root_.io.circe.Encoder[java.time.Instant]", + q"implicit val guardrailEncodeLocalDate: _root_.io.circe.Encoder[java.time.LocalDate] = _root_.io.circe.Encoder[java.time.LocalDate]", + q"implicit val guardrailEncodeLocalDateTime: _root_.io.circe.Encoder[java.time.LocalDateTime] = _root_.io.circe.Encoder[java.time.LocalDateTime]", + q"implicit val guardrailEncodeLocalTime: _root_.io.circe.Encoder[java.time.LocalTime] = _root_.io.circe.Encoder[java.time.LocalTime]", + q"implicit val guardrailEncodeOffsetDateTime: _root_.io.circe.Encoder[java.time.OffsetDateTime] = _root_.io.circe.Encoder[java.time.OffsetDateTime]", + q"implicit val guardrailEncodeZonedDateTime: _root_.io.circe.Encoder[java.time.ZonedDateTime] = _root_.io.circe.Encoder[java.time.ZonedDateTime]", + q"implicit val guardrailEncodeBase64String: _root_.io.circe.Encoder[Base64String] = _root_.io.circe.Encoder[_root_.scala.Predef.String].contramap[Base64String](v => new _root_.scala.Predef.String(java.util.Base64.getEncoder.encode(v.data)))" + ) + ) + + private def implicitsObject() = Target.pure(None) + + private def extractSuperClass( + swagger: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])] + ) = { + def allParents: Tracker[Schema[_]] => Target[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]] = + _.refine[Target[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]]] { case x: ComposedSchema => x }( + _.downField("allOf", _.getAllOf()).indexedDistribute.filter(_.downField("$ref", _.get$ref()).unwrapTracker.nonEmpty) match { + case head :: tail => + definitions + .collectFirst { + case (clsName, e) if head.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName")) => + val thisParent = (clsName, e, tail) + allParents(e).map(otherParents => thisParent :: otherParents) + } + .getOrElse( + Target.raiseUserError(s"Reference ${head.downField("$ref", _.get$ref()).unwrapTracker} not found among definitions (${head.showHistory})") + ) + case _ => Target.pure(List.empty) + } + ).getOrElse(Target.pure(List.empty)) + allParents(swagger) + } + + private def renderADTStaticDefns( + clsName: String, + discriminator: Discriminator[ScalaLanguage], + encoder: Option[scala.meta.Defn.Val], + decoder: Option[scala.meta.Defn.Val] + ) = + Target.pure( + StaticDefns[ScalaLanguage]( + className = clsName, + extraImports = List.empty[Import], + definitions = List[Option[Defn]]( + Some(q"val discriminator: String = ${Lit.String(discriminator.propertyName)}"), + encoder, + decoder + ).flatten + ) ) + private def decodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { + val (childrenCases, childrenDiscriminators) = children.map { child => + val discriminatorValue = discriminator.mapping + .collectFirst { case (value, elem) if elem.name == child => value } + .getOrElse(child) + ( + p"case ${Lit.String(discriminatorValue)} => c.as[${Type.Name(child)}]", + discriminatorValue + ) + }.unzip + val code = + q"""implicit val decoder: _root_.io.circe.Decoder[${Type.Name(clsName)}] = _root_.io.circe.Decoder.instance({ c => + val discriminatorCursor = c.downField(discriminator) + discriminatorCursor.as[String].flatMap { + ..case $childrenCases; + case tpe => + _root_.scala.Left(_root_.io.circe.DecodingFailure("Unknown value " ++ tpe ++ ${Lit + .String(s" (valid: ${childrenDiscriminators.mkString(", ")})")}, discriminatorCursor.history)) + } + })""" + Target.pure(Some(code)) + } + + private def encodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = { + val childrenCases = children.map { child => + val discriminatorValue = discriminator.mapping + .collectFirst { case (value, elem) if elem.name == child => value } + .getOrElse(child) + p"case e:${Type.Name(child)} => e.asJsonObject.add(discriminator, _root_.io.circe.Json.fromString(${Lit.String(discriminatorValue)})).asJson" + } + val code = + q"""implicit val encoder: _root_.io.circe.Encoder[${Type.Name(clsName)}] = _root_.io.circe.Encoder.instance { + ..case $childrenCases + }""" + Target.pure(Some(code)) + } + + private def renderSealedTrait( + className: String, + params: List[ProtocolParameter[ScalaLanguage]], + discriminator: Discriminator[ScalaLanguage], + parents: List[SuperClass[ScalaLanguage]] = Nil, + children: List[String] = Nil + ) = + for { + testTerms <- + params + .map(_.term) + .filter(_.name.value != discriminator.propertyName) + .traverse { t => + for { + tpe <- Target.fromOption( + t.decltpe + .flatMap { + case tpe: Type => Some(tpe) + case x => None + }, + UserError(t.decltpe.fold("Nothing to map")(x => s"Unsure how to map ${x.structure}, please report this bug!")) + ) + } yield q"""def ${Term.Name(t.name.value)}: ${tpe}""" + } + } yield parents.headOption + .fold(q"""trait ${Type.Name(className)} {..${testTerms}}""")(parent => + q"""trait ${Type.Name(className)} extends ${init"${Type.Name(parent.clsName)}(...$Nil)"} { ..${testTerms} } """ + ) } diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala index da1dea67e3..fcefed7d30 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala @@ -1,19 +1,42 @@ package dev.guardrail.generators.scala.jackson +import _root_.io.swagger.v3.oas.models.media.{ Discriminator => _, _ } +import _root_.io.swagger.v3.oas.models.{ Components, OpenAPI } +import cats.Foldable +import cats.Monad import cats.data.NonEmptyList import cats.syntax.all._ - +import scala.jdk.CollectionConverters._ import scala.meta._ import scala.reflect.runtime.universe.typeTag -import dev.guardrail.core.{ EmptyIsNull, SupportDefinition } -import dev.guardrail.generators.scala.{ CirceModelGenerator, JacksonModelGenerator, ScalaGenerator, ScalaLanguage } -import dev.guardrail.generators.scala.circe.CirceProtocolGenerator + +import dev.guardrail.core +import dev.guardrail.core.extract.{ DataRedaction, Default, EmptyValueIsNull } +import dev.guardrail.core.{ DataRedacted, DataVisible, EmptyIsEmpty, EmptyIsNull, LiteralRawType, Mappish, ReifiedRawType, SupportDefinition, Tracker } +import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEnumSchema, wrapObjectEnumSchema, wrapStringEnumSchema } +import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent } +import dev.guardrail.generators.scala.{ JacksonModelGenerator, ScalaGenerator, ScalaLanguage } import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } +import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.PropertyRequirement.{ Optional, RequiredNullable } import dev.guardrail.terms.protocol._ import dev.guardrail.terms.protocol.{ Discriminator, PropertyRequirement } -import dev.guardrail.terms.ProtocolTerms -import dev.guardrail.{ RuntimeFailure, Target } +import dev.guardrail.terms.{ + CollectionsLibTerms, + HeldEnum, + IntHeldEnum, + LanguageTerms, + LongHeldEnum, + ProtocolTerms, + RenderedEnum, + RenderedIntEnum, + RenderedLongEnum, + RenderedStringEnum, + StringHeldEnum, + SwaggerTerms +} +import dev.guardrail.{ RuntimeFailure, SwaggerUtil, Target, UserError } class JacksonProtocolGeneratorLoader extends ProtocolGeneratorLoader { type L = ScalaLanguage @@ -25,11 +48,23 @@ class JacksonProtocolGeneratorLoader extends ProtocolGeneratorLoader { } object JacksonProtocolGenerator { + def apply: ProtocolTerms[ScalaLanguage, Target] = + new JacksonProtocolGenerator +} + +class JacksonProtocolGenerator private extends ProtocolTerms[ScalaLanguage, Target] { + + override implicit def MonadF: Monad[Target] = Target.targetInstances + + import Target.targetInstances // TODO: Remove me. This resolves implicit ambiguity from MonadChain + private def discriminatorValue(discriminator: Discriminator[ScalaLanguage], className: String): String = discriminator.mapping .collectFirst { case (value, elem) if elem.name == className => value } .getOrElse(className) + private val jsonIgnoreProperties = mod"""@com.fasterxml.jackson.annotation.JsonIgnoreProperties(ignoreUnknown = true)""" + private def paramAnnotations( param: ProtocolParameter[ScalaLanguage], presenceSerType: Type, @@ -121,334 +156,1227 @@ object JacksonProtocolGenerator { case _ => param.term.default } - private val jsonIgnoreProperties = mod"""@com.fasterxml.jackson.annotation.JsonIgnoreProperties(ignoreUnknown = true)""" + override def fromSwagger( + swagger: Tracker[OpenAPI], + dtoPackage: List[String], + supportPackage: NonEmptyList[String], + defaultPropertyRequirement: PropertyRequirement + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolDefinitions[ScalaLanguage]] = { + import Sc._ - def apply: ProtocolTerms[ScalaLanguage, Target] = { - val baseInterp = CirceProtocolGenerator(CirceModelGenerator.V012) - import Target.targetInstances + val components = swagger.downField("components", _.getComponents()) + val definitions = components.flatDownField("schemas", _.getSchemas()).indexedCosequence + Sw.log.function("ProtocolGenerator.fromSwagger")(for { + (hierarchies, definitionsWithoutPoly) <- groupHierarchies(definitions) - baseInterp.copy( - renderDTOClass = (className, supportPackage, terms, parents) => - for { - renderedClass <- baseInterp.renderDTOClass(className, supportPackage, terms, parents) - discriminatorParams = parents.flatMap(parent => - parent.discriminators.flatMap(discrim => parent.params.find(_.name.value == discrim.propertyName).map((discrim, _))) + concreteTypes <- SwaggerUtil.extractConcreteTypes[ScalaLanguage, Target](definitions.value, components) + polyADTs <- hierarchies.traverse(fromPoly(_, concreteTypes, definitions.value, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components)) + elems <- definitionsWithoutPoly.traverse { case (clsName, model) => + model + .refine { case c: ComposedSchema => c }(comp => + for { + formattedClsName <- formatTypeName(clsName) + parents <- extractParents(comp, definitions.value, concreteTypes, dtoPackage, supportPackage.toList, defaultPropertyRequirement, components) + model <- fromModel( + clsName = NonEmptyList.of(formattedClsName), + model = comp, + parents = parents, + concreteTypes = concreteTypes, + definitions = definitions.value, + dtoPackage = dtoPackage, + supportPackage = supportPackage.toList, + defaultPropertyRequirement = defaultPropertyRequirement, + components = components + ) + alias <- modelTypeAlias(formattedClsName, comp, components) + } yield model.getOrElse(alias) ) - discriminators <- discriminatorParams.traverse { case (discriminator, param) => + .orRefine { case a: ArraySchema => a }(arr => for { - discrimTpe <- Target.fromOption(param.term.decltpe, RuntimeFailure(s"Property ${param.name.value} has no type")) - discrimValue <- JacksonHelpers - .discriminatorExpression[ScalaLanguage]( - param.name.value, - discriminatorValue(discriminator, className), - param.rawType - )( - v => Target.pure[Term](q"""BigInt(${Lit.String(v)})"""), - v => Target.pure[Term](q"""BigDecimal(${Lit.String(v)})"""), - v => - param.term.decltpe.fold( - Target.raiseUserError[Term](s"No declared type for property '${param.name.value}' on class $className") - ) { - case tpe @ (_: Type.Name | _: Type.Select) => - ScalaGenerator().formatEnumName(v).map(ev => Term.Select(Term.Name(tpe.toString), Term.Name(ev))) - case tpe => Target.raiseError(RuntimeFailure(s"Assumed property ${param.name.value} was an enum, but can't handle $tpe")) - } - )(ScalaGenerator()) - } yield (param.name.value, param.term.name.value, param.term.decltpe, discrimValue) - } - presenceSerType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "PresenceSerializer")) - presenceDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "PresenceDeserializer")) - optionNonNullDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonNullDeserializer")) - emptyIsNullDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "EmptyIsNullDeserializers", "EmptyIsNullDeserializer")) - emptyIsNullOptionDeserType <- ScalaGenerator().selectType( - NonEmptyList.ofInitLast(supportPackage :+ "EmptyIsNullDeserializers", "EmptyIsNullOptionDeserializer") + formattedClsName <- formatTypeName(clsName) + array <- fromArray(formattedClsName, arr, concreteTypes, components) + } yield array + ) + .orRefine { case o: ObjectSchema => o }(m => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum[Object](formattedClsName, m, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + m, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + alias <- modelTypeAlias(formattedClsName, m, components) + } yield enum.orElse(model).getOrElse(alias) ) - optionNonMissingDeserType <- ScalaGenerator().selectType( - NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonMissingDeserializer") + .orRefine { case x: StringSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components + ) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) ) - allTerms = terms ++ parents.flatMap(_.params) - } yield renderedClass.copy( - mods = jsonIgnoreProperties +: renderedClass.mods, - ctor = renderedClass.ctor.copy( - paramss = renderedClass.ctor.paramss.map( - _.map(param => - allTerms - .find(_.term.name.value == param.name.value) - .fold(param) { term => - param.copy( - mods = paramAnnotations( - term, - presenceSerType, - presenceDeserType, - optionNonNullDeserType, - optionNonMissingDeserType, - emptyIsNullDeserType, - emptyIsNullOptionDeserType - ) ++ param.mods, - default = fixDefaultValue(term) - ) - } + .orRefine { case x: IntegerSchema => x }(x => + for { + formattedClsName <- formatTypeName(clsName) + enum <- fromEnum(formattedClsName, x, dtoPackage, components) + model <- fromModel( + NonEmptyList.of(formattedClsName), + x, + List.empty, + concreteTypes, + definitions.value, + dtoPackage, + supportPackage.toList, + defaultPropertyRequirement, + components ) - ) - ), - templ = renderedClass.templ.copy( - stats = discriminators.map { case (propertyName, fieldName, tpe, value) => - q""" - @com.fasterxml.jackson.annotation.JsonProperty(${Lit.String(propertyName)}) - val ${Pat.Var(Term.Name(fieldName))}: $tpe = $value - """ - } ++ renderedClass.templ.stats + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + alias <- typeAlias(formattedClsName, declType) + } yield enum.orElse(model).getOrElse(alias) + ) + .valueOr(x => + for { + formattedClsName <- formatTypeName(clsName) + customTypeName <- SwaggerUtil.customTypeName(x) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, customTypeName), components) + res <- typeAlias(formattedClsName, declType) + } yield res ) - ), - encodeModel = (_, _, _, _) => Target.pure(None), - decodeModel = (_, _, _, _, _) => Target.pure(None), - renderDTOStaticDefns = (className, deps, encoder, decoder, params) => + } + protoImports <- protocolImports() + pkgImports <- packageObjectImports() + pkgObjectContents <- packageObjectContents() + implicitsObject <- implicitsObject() + + polyADTElems <- ProtocolElems.resolve[ScalaLanguage, Target](polyADTs) + strictElems <- ProtocolElems.resolve[ScalaLanguage, Target](elems) + } yield ProtocolDefinitions[ScalaLanguage](strictElems ++ polyADTElems, protoImports, pkgImports, pkgObjectContents, implicitsObject)) + } + + private[this] def getRequiredFieldsRec(root: Tracker[Schema[_]]): List[String] = { + @scala.annotation.tailrec + def work(values: List[Tracker[Schema[_]]], acc: List[String]): List[String] = { + val required: List[String] = values.flatMap(_.downField("required", _.getRequired()).unwrapTracker) + val next: List[Tracker[Schema[_]]] = + for { + a <- values + b <- a.refine { case x: ComposedSchema => x }(_.downField("allOf", _.getAllOf())).toOption.toList + c <- b.indexedDistribute + } yield c + + val newRequired = acc ++ required + + next match { + case next @ (_ :: _) => work(next, newRequired) + case Nil => newRequired + } + } + work(List(root), Nil) + } + + private[this] def fromEnum[A]( + clsName: String, + schema: Tracker[Schema[A]], + dtoPackage: List[String], + components: Tracker[Option[Components]] + )(implicit + P: ProtocolTerms[ScalaLanguage, Target], + F: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target], + wrapEnumSchema: WrapEnumSchema[A] + ): Target[Either[String, EnumDefinition[ScalaLanguage]]] = { + import Sc._ + import Sw._ + + def validProg(held: HeldEnum, tpe: scala.meta.Type, fullType: scala.meta.Type): Target[EnumDefinition[ScalaLanguage]] = + for { + (pascalValues, wrappedValues) <- held match { + case StringHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(elem) + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedStringEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case IntHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedIntEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + case LongHeldEnum(value) => + for { + elems <- value.traverse { elem => + for { + termName <- formatEnumName(s"${clsName}${elem}") // TODO: Push this string into LanguageTerms + valueTerm <- pureTermName(termName) + accessor <- buildAccessor(clsName, termName) + } yield (elem, valueTerm, accessor) + } + pascalValues = elems.map(_._2) + wrappedValues = RenderedLongEnum[ScalaLanguage](elems) + } yield (pascalValues, wrappedValues) + } + members <- renderMembers(clsName, wrappedValues) + encoder <- encodeEnum(clsName, tpe) + decoder <- decodeEnum(clsName, tpe) + + defn <- renderClass(clsName, tpe, wrappedValues) + staticDefns <- renderStaticDefns(clsName, tpe, members, pascalValues, encoder, decoder) + classType <- pureTypeName(clsName) + } yield EnumDefinition[ScalaLanguage](clsName, classType, fullType, wrappedValues, defn, staticDefns) + + for { + enum <- extractEnum(schema.map(wrapEnumSchema)) + customTpeName <- SwaggerUtil.customTypeName(schema) + (tpe, _) <- SwaggerUtil.determineTypeName(schema, Tracker.cloneHistory(schema, customTpeName), components) + fullType <- selectType(NonEmptyList.ofInitLast(dtoPackage, clsName)) + res <- enum.traverse(validProg(_, tpe, fullType)) + } yield res + } + + private[this] def getPropertyRequirement( + schema: Tracker[Schema[_]], + isRequired: Boolean, + defaultPropertyRequirement: PropertyRequirement + ): PropertyRequirement = + (for { + isNullable <- schema.downField("nullable", _.getNullable) + } yield (isRequired, isNullable) match { + case (true, None) => PropertyRequirement.Required + case (true, Some(false)) => PropertyRequirement.Required + case (true, Some(true)) => PropertyRequirement.RequiredNullable + case (false, None) => defaultPropertyRequirement + case (false, Some(false)) => PropertyRequirement.Optional + case (false, Some(true)) => PropertyRequirement.OptionalNullable + }).unwrapTracker + + /** Handle polymorphic model + */ + private[this] def fromPoly( + hierarchy: ClassParent[ScalaLanguage], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Sc._ + + def child(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + hierarchy.children.map(_.name) ::: hierarchy.children.flatMap(child) + def parent(hierarchy: ClassHierarchy[ScalaLanguage]): List[String] = + if (hierarchy.children.nonEmpty) hierarchy.name :: hierarchy.children.flatMap(parent) + else Nil + + val children = child(hierarchy).diff(parent(hierarchy)).distinct + val discriminator = hierarchy.discriminator + + for { + parents <- hierarchy.model + .refine[Target[List[SuperClass[ScalaLanguage]]]] { case c: ComposedSchema => c }( + extractParents(_, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + ) + .getOrElse(List.empty[SuperClass[ScalaLanguage]].pure[Target]) + props <- extractProperties(hierarchy.model) + requiredFields = hierarchy.required ::: hierarchy.children.flatMap(_.required) + params <- props.traverse { case (name, prop) => for { - renderedDTOStaticDefns <- baseInterp.renderDTOStaticDefns(className, deps, encoder, decoder, List.empty) - classType = Type.Name(className) - } yield renderedDTOStaticDefns.copy( - definitions = renderedDTOStaticDefns.definitions ++ List( - q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$classType] = GuardrailEncoder.instance", - q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$classType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$classType] {})", - q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$classType] = GuardrailValidator.instance" + typeName <- formatTypeName(name).map(formattedName => NonEmptyList.of(hierarchy.name, formattedName)) + propertyRequirement = getPropertyRequirement(prop, requiredFields.contains(name), defaultPropertyRequirement) + customType <- SwaggerUtil.customTypeName(prop) + resolvedType <- SwaggerUtil + .propMeta[ScalaLanguage, Target]( + prop, + components + ) // TODO: This should be resolved via an alternate mechanism that maintains references all the way through, instead of re-deriving and assuming that references are valid + defValue <- defaultValue(typeName, prop, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + res <- transformProperty(hierarchy.name, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + prop, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue ) - ), - renderClass = (className, tpe, elems) => + } yield res + } + definition <- renderSealedTrait(hierarchy.name, params, discriminator, parents, children) + encoder <- encodeADT(hierarchy.name, hierarchy.discriminator, children) + decoder <- decodeADT(hierarchy.name, hierarchy.discriminator, children) + staticDefns <- renderADTStaticDefns(hierarchy.name, discriminator, encoder, decoder) + tpe <- pureTypeName(hierarchy.name) + fullType <- selectType(NonEmptyList.fromList(dtoPackage :+ hierarchy.name).getOrElse(NonEmptyList.of(hierarchy.name))) + } yield ADT[ScalaLanguage]( + name = hierarchy.name, + tpe = tpe, + fullType = fullType, + trt = definition, + staticDefns = staticDefns + ) + } + + private def extractParents( + elem: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])], + concreteTypes: List[PropMeta[ScalaLanguage]], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[List[SuperClass[ScalaLanguage]]] = { + import Sc._ + + for { + a <- extractSuperClass(elem, definitions) + supper <- a.flatTraverse { case (clsName, _extends, interfaces) => + val concreteInterfacesWithClass = for { + interface <- interfaces + (cls, tracker) <- definitions + result <- tracker + .refine[Tracker[Schema[_]]] { + case x: ComposedSchema if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x + }( + identity _ + ) + .orRefine { case x: Schema[_] if interface.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/${cls}")) => x }(identity _) + .toOption + } yield cls -> result + val (_, concreteInterfaces) = concreteInterfacesWithClass.unzip + val classMapping = (for { + (cls, schema) <- concreteInterfacesWithClass + (name, _) <- schema.downField("properties", _.getProperties).indexedDistribute.value + } yield (name, cls)).toMap for { - renderedClass <- baseInterp.renderClass(className, tpe, elems) - } yield renderedClass.copy( - mods = List( - mod"@com.fasterxml.jackson.databind.annotation.JsonSerialize(using=classOf[${Type.Select(Term.Name(className), Type.Name(className + "Serializer"))}])", - mod"@com.fasterxml.jackson.databind.annotation.JsonDeserialize(using=classOf[${Type.Select(Term.Name(className), Type.Name(className + "Deserializer"))}])" - ) ++ renderedClass.mods - ), - encodeEnum = { (className, tpe) => + _extendsProps <- extractProperties(_extends) + requiredFields = getRequiredFieldsRec(_extends) ++ concreteInterfaces.flatMap(getRequiredFieldsRec) + _withProps <- concreteInterfaces.traverse(extractProperties) + props = _extendsProps ++ _withProps.flatten + (params, _) <- prepareProperties( + NonEmptyList.of(clsName), + classMapping, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + interfacesCls = interfaces.flatMap(_.downField("$ref", _.get$ref).unwrapTracker.map(_.split("/").last)) + tpe <- parseTypeName(clsName) + + discriminators <- (_extends :: concreteInterfaces).flatTraverse( + _.refine[Target[List[Discriminator[ScalaLanguage]]]] { case m: ObjectSchema => m }(m => Discriminator.fromSchema(m).map(_.toList)) + .getOrElse(List.empty[Discriminator[ScalaLanguage]].pure[Target]) + ) + } yield tpe + .map( + SuperClass[ScalaLanguage]( + clsName, + _, + interfacesCls, + params, + discriminators + ) + ) + .toList + } + + } yield supper + } + + private[this] def fromModel( + clsName: NonEmptyList[String], + model: Tracker[Schema[_]], + parents: List[SuperClass[ScalaLanguage]], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[Either[String, ClassDefinition[ScalaLanguage]]] = { + import Sc._ + + for { + props <- extractProperties(model) + requiredFields = getRequiredFieldsRec(model) + (params, nestedDefinitions) <- prepareProperties( + clsName, + Map.empty, + props, + requiredFields, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + encoder <- encodeModel(clsName.last, dtoPackage, params, parents) + decoder <- decodeModel(clsName.last, dtoPackage, supportPackage, params, parents) + tpe <- parseTypeName(clsName.last) + fullType <- selectType(dtoPackage.foldRight(clsName)((x, xs) => xs.prepend(x))) + staticDefns <- renderDTOStaticDefns(clsName.last, List.empty, encoder, decoder, params) + nestedClasses <- nestedDefinitions.flatTraverse { + case classDefinition: ClassDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(classDefinition.cls) + companionTerm <- pureTermName(classDefinition.name) + companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_)) + case enumDefinition: EnumDefinition[ScalaLanguage] => + for { + widenClass <- widenClassDefinition(enumDefinition.cls) + companionTerm <- pureTermName(enumDefinition.name) + companionDefinition <- wrapToObject(companionTerm, enumDefinition.staticDefns.extraImports, enumDefinition.staticDefns.definitions) + widenCompanion <- companionDefinition.traverse(widenObjectDefinition) + } yield List(widenClass) ++ widenCompanion.fold(enumDefinition.staticDefns.definitions)(List(_)) + } + defn <- renderDTOClass(clsName.last, supportPackage, params, parents) + } yield { + val finalStaticDefns = staticDefns.copy(definitions = staticDefns.definitions ++ nestedClasses) + if (parents.isEmpty && props.isEmpty) Left("Entity isn't model"): Either[String, ClassDefinition[ScalaLanguage]] + else tpe.toRight("Empty entity name").map(ClassDefinition[ScalaLanguage](clsName.last, _, fullType, defn, finalStaticDefns, parents)) + } + } + + private def prepareProperties( + clsName: NonEmptyList[String], + propertyToTypeLookup: Map[String, String], + props: List[(String, Tracker[Schema[_]])], + requiredFields: List[String], + concreteTypes: List[PropMeta[ScalaLanguage]], + definitions: List[(String, Tracker[Schema[_]])], + dtoPackage: List[String], + supportPackage: List[String], + defaultPropertyRequirement: PropertyRequirement, + components: Tracker[Option[Components]] + )(implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ProtocolParameter[ScalaLanguage]], List[NestedProtocolElems[ScalaLanguage]])] = { + import Sc._ + def getClsName(name: String): NonEmptyList[String] = propertyToTypeLookup.get(name).map(NonEmptyList.of(_)).getOrElse(clsName) + + def processProperty(name: String, schema: Tracker[Schema[_]]): Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]] = + for { + nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + defn <- schema + .refine[Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]]] { case x: ObjectSchema => x }(o => + for { + defn <- fromModel( + nestedClassName, + o, + List.empty, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(defn) + ) + .orRefine { case o: ComposedSchema => o }(o => + for { + parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components) + maybeClassDefinition <- fromModel( + nestedClassName, + o, + parents, + concreteTypes, + definitions, + dtoPackage, + supportPackage, + defaultPropertyRequirement, + components + ) + } yield Option(maybeClassDefinition) + ) + .orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _))) + .orRefine { case s: StringSchema if Option(s.getEnum).map(_.asScala).exists(_.nonEmpty) => s }(s => + fromEnum(nestedClassName.last, s, dtoPackage, components).map(Option(_)) + ) + .getOrElse(Option.empty[Either[String, NestedProtocolElems[ScalaLanguage]]].pure[Target]) + } yield defn + + for { + paramsAndNestedDefinitions <- props.traverse[Target, (Tracker[ProtocolParameter[ScalaLanguage]], Option[NestedProtocolElems[ScalaLanguage]])] { + case (name, schema) => + for { + typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName)) + tpe <- selectType(typeName) + maybeNestedDefinition <- processProperty(name, schema) + resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, components) + customType <- SwaggerUtil.customTypeName(schema) + propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement) + defValue <- defaultValue(typeName, schema, propertyRequirement, definitions) + fieldName <- formatFieldName(name) + parameter <- transformProperty(getClsName(name).last, dtoPackage, supportPackage, concreteTypes)( + name, + fieldName, + schema, + resolvedType, + propertyRequirement, + customType.isDefined, + defValue + ) + } yield (Tracker.cloneHistory(schema, parameter), maybeNestedDefinition.flatMap(_.toOption)) + } + (params, nestedDefinitions) = paramsAndNestedDefinitions.unzip + deduplicatedParams <- deduplicateParams(params) + unconflictedParams <- fixConflictingNames(deduplicatedParams) + } yield (unconflictedParams, nestedDefinitions.flatten) + } + + private def deduplicateParams( + params: List[Tracker[ProtocolParameter[ScalaLanguage]]] + )(implicit Sw: SwaggerTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Sc._ + Foldable[List] + .foldLeftM[Target, Tracker[ProtocolParameter[ScalaLanguage]], List[ProtocolParameter[ScalaLanguage]]]( + params, + List.empty[ProtocolParameter[ScalaLanguage]] + ) { (s, ta) => + val a = ta.unwrapTracker + s.find(p => p.name == a.name) match { + case None => (a :: s).pure[Target] + case Some(duplicate) => + for { + newDefaultValue <- findCommonDefaultValue(ta.showHistory, a.defaultValue, duplicate.defaultValue) + newRawType <- findCommonRawType(ta.showHistory, a.rawType, duplicate.rawType) + } yield { + val emptyToNull = if (Set(a.emptyToNull, duplicate.emptyToNull).contains(EmptyIsNull)) EmptyIsNull else EmptyIsEmpty + val redactionBehaviour = if (Set(a.dataRedaction, duplicate.dataRedaction).contains(DataRedacted)) DataRedacted else DataVisible + val mergedParameter = ProtocolParameter[ScalaLanguage]( + a.term, + a.baseType, + a.name, + a.dep, + newRawType, + a.readOnlyKey.orElse(duplicate.readOnlyKey), + emptyToNull, + redactionBehaviour, + a.propertyRequirement, + newDefaultValue, + a.propertyValidation + ) + mergedParameter :: s.filter(_.name != a.name) + } + } + } + .map(_.reverse) + } + + private def fixConflictingNames( + params: List[ProtocolParameter[ScalaLanguage]] + )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[List[ProtocolParameter[ScalaLanguage]]] = { + import Lt._ + for { + paramsWithNames <- params.traverse(param => extractTermNameFromParam(param.term).map((_, param))) + counts = paramsWithNames.groupBy(_._1).view.mapValues(_.length).toMap + newParams <- paramsWithNames.traverse { case (name, param) => + if (counts.getOrElse(name, 0) > 1) { + for { + newTermName <- pureTermName(param.name.value) + newMethodParam <- alterMethodParameterName(param.term, newTermName) + } yield ProtocolParameter( + newMethodParam, + param.baseType, + param.name, + param.dep, + param.rawType, + param.readOnlyKey, + param.emptyToNull, + param.dataRedaction, + param.propertyRequirement, + param.defaultValue, + param.propertyValidation + ) + } else { + param.pure[Target] + } + } + } yield newParams + } + + private def modelTypeAlias(clsName: String, abstractModel: Tracker[Schema[_]], components: Tracker[Option[Components]])(implicit + Fw: FrameworkTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + val model: Option[Tracker[ObjectSchema]] = abstractModel + .refine[Option[Tracker[ObjectSchema]]] { case m: ObjectSchema => m }(x => Option(x)) + .orRefine { case m: ComposedSchema => m }( + _.downField("allOf", _.getAllOf()).indexedCosequence + .get(1) + .flatMap( + _.refine { case o: ObjectSchema => o }(Option.apply) + .orRefineFallback(_ => None) + ) + ) + .orRefineFallback(_ => None) + for { + tpe <- model.fold[Target[scala.meta.Type]](objectType(None)) { m => for { - writeMethod <- tpe match { - case t"Int" => Target.pure(q"writeNumber") - case t"Long" => Target.pure(q"writeNumber") - case t"String" => Target.pure(q"writeString") - case other => Target.raiseException(s"Unexpected type during enumeration encoder: ${className} was ${other}") + tpeName <- SwaggerUtil.customTypeName[ScalaLanguage, Target, Tracker[ObjectSchema]](m) + (declType, _) <- SwaggerUtil.determineTypeName[ScalaLanguage, Target](m, Tracker.cloneHistory(m, tpeName), components) + } yield declType + } + res <- typeAlias(clsName, tpe) + } yield res + } + + private def plainTypeAlias( + clsName: String + )(implicit Fw: FrameworkTerms[ScalaLanguage, Target], Sc: LanguageTerms[ScalaLanguage, Target]): Target[ProtocolElems[ScalaLanguage]] = { + import Fw._ + for { + tpe <- objectType(None) + res <- typeAlias(clsName, tpe) + } yield res + } + + private def typeAlias(clsName: String, tpe: scala.meta.Type): Target[ProtocolElems[ScalaLanguage]] = + (RandomType[ScalaLanguage](clsName, tpe): ProtocolElems[ScalaLanguage]).pure[Target] + + private def fromArray(clsName: String, arr: Tracker[ArraySchema], concreteTypes: List[PropMeta[ScalaLanguage]], components: Tracker[Option[Components]])( + implicit + F: FrameworkTerms[ScalaLanguage, Target], + P: ProtocolTerms[ScalaLanguage, Target], + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[ProtocolElems[ScalaLanguage]] = + for { + deferredTpe <- SwaggerUtil.modelMetaType(arr, components) + tpe <- extractArrayType(deferredTpe, concreteTypes) + ret <- typeAlias(clsName, tpe) + } yield ret + + /** returns objects grouped into hierarchies + */ + private def groupHierarchies( + definitions: Mappish[List, String, Tracker[Schema[_]]] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Sw: SwaggerTerms[ScalaLanguage, Target] + ): Target[(List[ClassParent[ScalaLanguage]], List[(String, Tracker[Schema[_]])])] = { + + def firstInHierarchy(model: Tracker[Schema[_]]): Option[Tracker[ObjectSchema]] = + model + .refine { case x: ComposedSchema => x } { elem => + definitions.value + .collectFirst { + case (clsName, element) + if elem.downField("allOf", _.getAllOf).exists(_.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName"))) => + element + } + .flatMap( + _.refine { case x: ComposedSchema => x }(firstInHierarchy) + .orRefine { case o: ObjectSchema => o }(x => Option(x)) + .getOrElse(None) + ) + } + .getOrElse(None) + + def children(cls: String): List[ClassChild[ScalaLanguage]] = definitions.value.flatMap { case (clsName, comp) => + comp + .refine { case x: ComposedSchema => x }(comp => + if ( + comp + .downField("allOf", _.getAllOf()) + .exists(x => x.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$cls"))) + ) { + Some(ClassChild(clsName, comp, children(clsName), getRequiredFieldsRec(comp))) + } else None + ) + .getOrElse(None) + } + + def classHierarchy(cls: String, model: Tracker[Schema[_]]): Target[Option[ClassParent[ScalaLanguage]]] = + model + .refine { case c: ComposedSchema => c }(c => + firstInHierarchy(c) + .fold(Option.empty[Discriminator[ScalaLanguage]].pure[Target])(Discriminator.fromSchema[ScalaLanguage, Target]) + .map(_.map((_, getRequiredFieldsRec(c)))) + ) + .orRefine { case x: Schema[_] => x }(m => Discriminator.fromSchema(m).map(_.map((_, getRequiredFieldsRec(m))))) + .getOrElse(Option.empty[(Discriminator[ScalaLanguage], List[String])].pure[Target]) + .map(_.map { case (discriminator, reqFields) => ClassParent(cls, model, children(cls), discriminator, reqFields) }) + + Sw.log.function("groupHierarchies")( + definitions.value + .traverse { case (cls, model) => + for { + hierarchy <- classHierarchy(cls, model) + } yield hierarchy.filterNot(_.children.isEmpty).toLeft((cls, model)) + } + .map(_.partitionEither[ClassParent[ScalaLanguage], (String, Tracker[Schema[_]])](identity)) + ) + } + + private def defaultValue( + name: NonEmptyList[String], + schema: Tracker[Schema[_]], + requirement: PropertyRequirement, + definitions: List[(String, Tracker[Schema[_]])] + )(implicit + Sc: LanguageTerms[ScalaLanguage, Target], + Cl: CollectionsLibTerms[ScalaLanguage, Target] + ): Target[Option[scala.meta.Term]] = { + import Sc._ + import Cl._ + val empty = Option.empty[scala.meta.Term].pure[Target] + schema.downField("$ref", _.get$ref()).indexedDistribute match { + case Some(ref) => + definitions + .collectFirst { + case (cls, refSchema) if ref.unwrapTracker.endsWith(s"/$cls") => + defaultValue(NonEmptyList.of(cls), refSchema, requirement, definitions) } - } yield Some( - q""" + .getOrElse(empty) + case None => + schema + .refine { case map: MapSchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => map }(map => + for { + customTpe <- SwaggerUtil.customMapTypeName(map) + result <- customTpe.fold(emptyMap().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case arr: ArraySchema if requirement == PropertyRequirement.Required || requirement == PropertyRequirement.RequiredNullable => arr }( + arr => + for { + customTpe <- SwaggerUtil.customArrayTypeName(arr) + result <- customTpe.fold(emptyArray().map(Option(_)))(_ => empty) + } yield result + ) + .orRefine { case p: BooleanSchema => p }(p => Default(p).extract[Boolean].fold(empty)(litBoolean(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "double" => p }(p => Default(p).extract[Double].fold(empty)(litDouble(_).map(Some(_)))) + .orRefine { case p: NumberSchema if p.getFormat == "float" => p }(p => Default(p).extract[Float].fold(empty)(litFloat(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int32" => p }(p => Default(p).extract[Int].fold(empty)(litInt(_).map(Some(_)))) + .orRefine { case p: IntegerSchema if p.getFormat == "int64" => p }(p => Default(p).extract[Long].fold(empty)(litLong(_).map(Some(_)))) + .orRefine { case p: StringSchema if Option(p.getEnum).map(_.asScala).exists(_.nonEmpty) => p }(p => + Default(p).extract[String] match { + case Some(defaultEnumValue) => + for { + enumName <- formatEnumName(defaultEnumValue) + result <- selectTerm(name.append(enumName)) + } yield Some(result) + case None => empty + } + ) + .orRefine { case p: StringSchema => p }(p => Default(p).extract[String].fold(empty)(litString(_).map(Some(_)))) + .getOrElse(empty) + } + } + + private def suffixClsName(prefix: String, clsName: String): Pat.Var = Pat.Var(Term.Name(s"${prefix}${clsName}")) + + private def lookupTypeName(tpeName: String, concreteTypes: List[PropMeta[ScalaLanguage]])(f: Type => Type): Option[Type] = + concreteTypes + .find(_.clsName == tpeName) + .map(_.tpe) + .map(f) + + private def renderMembers(clsName: String, elems: RenderedEnum[ScalaLanguage]) = { + val fields = elems match { + case RenderedStringEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.String(value)) + } + case RenderedIntEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.Int(value)) + } + case RenderedLongEnum(elems) => + elems.map { case (value, termName, _) => + (termName, Lit.Long(value)) + } + } + + Target.pure(Some(q""" + object members { + ..${fields.map { case (termName, lit) => q"""case object ${termName} extends ${Type.Name(clsName)}(${lit})""" }} + } + """)) + } + + private def encodeEnum(className: String, tpe: Type): Target[Option[Defn]] = + for { + writeMethod <- tpe match { + case t"Int" => Target.pure(q"writeNumber") + case t"Long" => Target.pure(q"writeNumber") + case t"String" => Target.pure(q"writeString") + case other => Target.raiseException(s"Unexpected type during enumeration encoder: ${className} was ${other}") + } + } yield Some( + q""" class ${Type.Name(className + "Serializer")} extends com.fasterxml.jackson.databind.JsonSerializer[${Type.Name(className)}] { override def serialize(value: ${Type - .Name( - className - )}, gen: com.fasterxml.jackson.core.JsonGenerator, serializers: com.fasterxml.jackson.databind.SerializerProvider): Unit = gen.${writeMethod}(value.value) + .Name( + className + )}, gen: com.fasterxml.jackson.core.JsonGenerator, serializers: com.fasterxml.jackson.databind.SerializerProvider): Unit = gen.${writeMethod}(value.value) } """ - ) - }, - decodeEnum = { (className, tpe) => - for { - getter <- tpe match { - case t"String" => Target.pure(q"getText") - case t"Int" => Target.pure(q"getIntValue") - case t"Long" => Target.pure(q"getLongValue") - case other => Target.raiseException(s"Unexpected type during enumeration decoder: ${className} was ${other}") - } - } yield Some( - q""" + ) + + private def decodeEnum(className: String, tpe: Type): Target[Option[Defn]] = + for { + getter <- tpe match { + case t"String" => Target.pure(q"getText") + case t"Int" => Target.pure(q"getIntValue") + case t"Long" => Target.pure(q"getLongValue") + case other => Target.raiseException(s"Unexpected type during enumeration decoder: ${className} was ${other}") + } + } yield Some( + q""" class ${Type.Name(className + "Deserializer")} extends com.fasterxml.jackson.databind.JsonDeserializer[${Type.Name(className)}] { override def deserialize(p: com.fasterxml.jackson.core.JsonParser, ctxt: com.fasterxml.jackson.databind.DeserializationContext): ${Type.Name( - className - )} = + className + )} = ${Term.Name(className)}.from(p.${getter}) .getOrElse({ throw new com.fasterxml.jackson.databind.JsonMappingException(p, s"Invalid value '$${p.${getter}}' for " + ${Lit - .String(className)}) }) + .String(className)}) }) } """ - ) - }, - renderStaticDefns = (className, elems, members, accessors, encoder, decoder) => - for { - renderedStaticDefns <- baseInterp.renderStaticDefns(className, elems, members, accessors, encoder, decoder) - classType = Type.Name(className) - } yield renderedStaticDefns.copy( - definitions = renderedStaticDefns.definitions ++ List( - q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$classType] = GuardrailEncoder.instance", - q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$classType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$classType] {})", - q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$classType] = GuardrailValidator.noop" - ) - ), - protocolImports = () => - Target.pure( + ) + + private def renderClass(className: String, tpe: scala.meta.Type, elems: RenderedEnum[ScalaLanguage]) = { + val jacksSer = + mod"@com.fasterxml.jackson.databind.annotation.JsonSerialize(using=classOf[${Type.Select(Term.Name(className), Type.Name(className + "Serializer"))}])" + val jacksDeser = + mod"@com.fasterxml.jackson.databind.annotation.JsonDeserialize(using=classOf[${Type.Select(Term.Name(className), Type.Name(className + "Deserializer"))}])" + Target.pure(q""" + ${jacksSer} ${jacksDeser} sealed abstract class ${Type.Name(className)}(val value: ${tpe}) extends _root_.scala.Product with _root_.scala.Serializable { + override def toString: String = value.toString + } + """) + } + + private def renderStaticDefns( + className: String, + tpe: scala.meta.Type, + members: Option[scala.meta.Defn.Object], + accessors: List[scala.meta.Term.Name], + encoder: Option[scala.meta.Defn], + decoder: Option[scala.meta.Defn] + ): Target[StaticDefns[ScalaLanguage]] = { + val longType = Type.Name(className) + val terms: List[Defn.Val] = accessors.map { pascalValue => + q"val ${Pat.Var(pascalValue)}: ${longType} = members.${pascalValue}" + }.toList + val values: Defn.Val = q"val values = _root_.scala.Vector(..$accessors)" + val implicits: List[Defn.Val] = List( + q"implicit val ${Pat.Var(Term.Name(s"show${className}"))}: Show[${longType}] = Show[${tpe}].contramap[${longType}](_.value)" + ) + Target.pure( + StaticDefns[ScalaLanguage]( + className = className, + extraImports = List.empty[Import], + definitions = members.toList ++ + terms ++ + List(Some(values), encoder, decoder).flatten ++ + implicits ++ List( - q"import cats.implicits._" + q"def from(value: ${tpe}): _root_.scala.Option[${longType}] = values.find(_.value == value)", + q"implicit val order: cats.Order[${longType}] = cats.Order.by[${longType}, Int](values.indexOf)" + ) ++ List( + q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$longType] = GuardrailEncoder.instance", + q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$longType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$longType] {})", + q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$longType] = GuardrailValidator.noop" ) - ), - packageObjectImports = () => Target.pure(List.empty), - packageObjectContents = () => Target.pure(List.empty), - implicitsObject = () => - Target.pure( - Some( - ( - q"JacksonImplicits", - q""" - object JacksonImplicits { - object constraints { - type NotNull = javax.validation.constraints.NotNull @scala.annotation.meta.field @scala.annotation.meta.param - } + ) + ) + } - trait GuardrailEncoder[A] { - def encode(a: A)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper): com.fasterxml.jackson.databind.JsonNode = - mapper.valueToTree(a) - } - object GuardrailEncoder { - def instance[A]: GuardrailEncoder[A] = new GuardrailEncoder[A] {} + override def buildAccessor(clsName: String, termName: String) = + Target.pure(q"${Term.Name(clsName)}.${Term.Name(termName)}") - implicit def guardrailEncodeOption[B: GuardrailEncoder]: GuardrailEncoder[Option[B]] = instance - implicit def guardrailEncodeVector[B: GuardrailEncoder]: GuardrailEncoder[Vector[B]] = instance - implicit def guardrailEncodeMap[B: GuardrailEncoder]: GuardrailEncoder[Map[String, B]] = instance - implicit val guardrailEncodeBoolean: GuardrailEncoder[Boolean] = instance - implicit val guardrailEncodeInt: GuardrailEncoder[Int] = instance - implicit val guardrailEncodeLong: GuardrailEncoder[Long] = instance - implicit val guardrailEncodeBigInt: GuardrailEncoder[BigInt] = instance - implicit val guardrailEncodeFloat: GuardrailEncoder[Float] = instance - implicit val guardrailEncodeDouble: GuardrailEncoder[Double] = instance - implicit val guardrailEncodeBigDecimal: GuardrailEncoder[BigDecimal] = instance - implicit val guardrailEncodeString: GuardrailEncoder[String] = instance - implicit val guardrailEncodeBase64String: GuardrailEncoder[Base64String] = instance - implicit val guardrailEncodeInstant: GuardrailEncoder[java.time.Instant] = instance - implicit val guardrailEncodeLocalDate: GuardrailEncoder[java.time.LocalDate] = instance - implicit val guardrailEncodeLocalDateTime: GuardrailEncoder[java.time.LocalDateTime] = instance - implicit val guardrailEncodeLocalTime: GuardrailEncoder[java.time.LocalTime] = instance - implicit val guardrailEncodeOffsetDateTime: GuardrailEncoder[java.time.OffsetDateTime] = instance - implicit val guardrailEncodeZonedDateTime: GuardrailEncoder[java.time.ZonedDateTime] = instance - } + private def extractProperties(swagger: Tracker[Schema[_]]) = + swagger + .refine[Target[List[(String, Tracker[Schema[_]])]]] { case o: ObjectSchema => o }(m => + Target.pure(m.downField("properties", _.getProperties()).indexedCosequence.value) + ) + .orRefine { case c: ComposedSchema => c } { comp => + val extractedProps = + comp + .downField("allOf", _.getAllOf()) + .indexedDistribute + .flatMap(_.downField("properties", _.getProperties).indexedCosequence.value) + Target.pure(extractedProps) + } + .orRefine { case x: Schema[_] if Option(x.get$ref()).isDefined => x }(comp => + Target.raiseUserError(s"Attempted to extractProperties for a ${comp.unwrapTracker.getClass()}, unsure what to do here (${comp.showHistory})") + ) + .getOrElse(Target.pure(List.empty[(String, Tracker[Schema[_]])])) - trait GuardrailDecoder[A] { - def tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] - def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[A]): scala.util.Try[A] = - scala.util.Try(this.tpe.fold(mapper.convertValue(jsonNode, _), mapper.convertValue(jsonNode, _))).flatMap(guardrailValidator.validate) - } - object GuardrailDecoder { - def instance[A](typeRef: com.fasterxml.jackson.core.`type`.TypeReference[A]): GuardrailDecoder[A] = new GuardrailDecoder[A] { - override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] = Left(typeRef) - } - def instance[A](cls: Class[A]): GuardrailDecoder[A] = new GuardrailDecoder[A] { - override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] = Right(cls) - } + private def transformProperty( + clsName: String, + dtoPackage: List[String], + supportPackage: List[String], + concreteTypes: List[PropMeta[ScalaLanguage]] + )( + name: String, + fieldName: String, + property: Tracker[Schema[_]], + meta: core.ResolvedType[ScalaLanguage], + requirement: PropertyRequirement, + isCustomType: Boolean, + defaultValue: Option[scala.meta.Term] + ): Target[ProtocolParameter[ScalaLanguage]] = + Target.log.function(s"transformProperty") { + val fallbackRawType = ReifiedRawType.of(property.downField("type", _.getType()).unwrapTracker, property.downField("format", _.getFormat()).unwrapTracker) + for { + _ <- Target.log.debug(s"Args: (${clsName}, ${name}, ...)") - implicit def guardrailDecodeOption[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Option[B]] = new GuardrailDecoder[Option[B]] { - override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Option[B]], Class[Option[B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Option[B]] {}) - override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Option[B]]): scala.util.Try[Option[B]] = { - if (jsonNode.isNull) { - scala.util.Success(Option.empty[B]) - } else { - implicitly[GuardrailDecoder[B]].decode(jsonNode).map(Option.apply) - } - } - } - implicit def guardrailDecodeVector[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Vector[B]] = new GuardrailDecoder[Vector[B]] { - override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Vector[B]], Class[Vector[B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Vector[B]] {}) - override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Vector[B]]): scala.util.Try[Vector[B]] = { - jsonNode match { - case arr: com.fasterxml.jackson.databind.node.ArrayNode => - import cats.implicits._ - import _root_.scala.jdk.CollectionConverters._ - arr.iterator().asScala.toVector.traverse(implicitly[GuardrailDecoder[B]].decode) - case _ => - scala.util.Failure(new com.fasterxml.jackson.databind.JsonMappingException(null, s"Can't decode to vector; node of type $${jsonNode.getClass.getSimpleName} is not an array")) - } - } - } - implicit def guardrailDecodeMap[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Map[String, B]] = new GuardrailDecoder[Map[String, B]] { - override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Map[String, B]], Class[Map[String, B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Map[String, B]] {}) - override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Map[String, B]]): scala.util.Try[Map[String, B]] = { - jsonNode match { - case obj: com.fasterxml.jackson.databind.node.ObjectNode => - import cats.implicits._ - import _root_.scala.jdk.CollectionConverters._ - obj.fields().asScala.toVector.traverse(entry => implicitly[GuardrailDecoder[B]].decode(entry.getValue).map((entry.getKey, _))).map(_.toMap) - case _ => - scala.util.Failure(new com.fasterxml.jackson.databind.JsonMappingException(null, s"Can't decode to map; node of type $${jsonNode.getClass.getSimpleName} is not an object")) - } - } - } - implicit val guardrailDecodeBoolean: GuardrailDecoder[Boolean] = instance(classOf[Boolean]) - implicit val guardrailDecodeInt: GuardrailDecoder[Int] = instance(classOf[Int]) - implicit val guardrailDecodeLong: GuardrailDecoder[Long] = instance(classOf[Long]) - implicit val guardrailDecodeBigInt: GuardrailDecoder[BigInt] = instance(classOf[BigInt]) - implicit val guardrailDecodeFloat: GuardrailDecoder[Float] = instance(classOf[Float]) - implicit val guardrailDecodeDouble: GuardrailDecoder[Double] = instance(classOf[Double]) - implicit val guardrailDecodeBigDecimal: GuardrailDecoder[BigDecimal] = instance(classOf[BigDecimal]) - implicit val guardrailDecodeString: GuardrailDecoder[String] = instance(classOf[String]) - implicit val guardrailDecodeBase64String: GuardrailDecoder[Base64String] = instance(classOf[Base64String]) - implicit val guardrailDecodeInstant: GuardrailDecoder[java.time.Instant] = instance(classOf[java.time.Instant]) - implicit val guardrailDecodeLocalDate: GuardrailDecoder[java.time.LocalDate] = instance(classOf[java.time.LocalDate]) - implicit val guardrailDecodeLocalDateTime: GuardrailDecoder[java.time.LocalDateTime] = instance(classOf[java.time.LocalDateTime]) - implicit val guardrailDecodeLocalTime: GuardrailDecoder[java.time.LocalTime] = instance(classOf[java.time.LocalTime]) - implicit val guardrailDecodeOffsetDateTime: GuardrailDecoder[java.time.OffsetDateTime] = instance(classOf[java.time.OffsetDateTime]) - implicit val guardrailDecodeZonedDateTime: GuardrailDecoder[java.time.ZonedDateTime] = instance(classOf[java.time.ZonedDateTime]) - } + readOnlyKey = Option(name).filter(_ => property.downField("readOnly", _.getReadOnly()).unwrapTracker.contains(true)) + emptyToNull = property + .refine { case d: DateSchema => d }(d => EmptyValueIsNull(d)) + .orRefine { case dt: DateTimeSchema => dt }(dt => EmptyValueIsNull(dt)) + .orRefine { case s: StringSchema => s }(s => EmptyValueIsNull(s)) + .toOption + .flatten + .getOrElse(EmptyIsEmpty) - trait GuardrailValidator[A] { - def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] - } - object GuardrailValidator { - def instance[A]: GuardrailValidator[A] = new GuardrailValidator[A] { - override def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] = { - import _root_.scala.jdk.CollectionConverters._ - scala.util.Try(validator.validate(a)).flatMap({ - case violations if violations.isEmpty => - scala.util.Success(a) - case violations => - scala.util.Failure(new javax.validation.ValidationException(s"Validation of $${a.getClass.getSimpleName} failed: $${violations.asScala.map(viol => s"$${viol.getPropertyPath}: $${viol.getMessage}").mkString("; ")}")) - }) - } - } - def noop[A]: GuardrailValidator[A] = new GuardrailValidator[A] { - override def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] = scala.util.Success(a) + dataRedaction = DataRedaction(property).getOrElse(DataVisible) + + (tpe, classDep, rawType) <- meta match { // TODO: Target is not used + case core.Resolved(declType, classDep, _, rawType @ LiteralRawType(Some(rawTypeStr), rawFormat)) + if SwaggerUtil.isFile(rawTypeStr, rawFormat) && !isCustomType => + // assume that binary data are represented as a string. allow users to override. + Target.pure((t"String", classDep, rawType)) + case core.Resolved(declType, classDep, _, rawType) => + Target.pure((declType, classDep, rawType)) + case core.Deferred(tpeName) => + val tpe = concreteTypes.find(_.clsName == tpeName).map(_.tpe).getOrElse { + println(s"Unable to find definition for ${tpeName}, just inlining") + Type.Name(tpeName) + } + Target.pure((tpe, Option.empty, fallbackRawType)) + case core.DeferredArray(tpeName, containerTpe) => + val concreteType = lookupTypeName(tpeName, concreteTypes)(identity) + val innerType = concreteType.getOrElse(Type.Name(tpeName)) + val tpe = t"${containerTpe.getOrElse(t"_root_.scala.Vector")}[$innerType]" + Target.pure((tpe, Option.empty, ReifiedRawType.ofVector(fallbackRawType))) + case core.DeferredMap(tpeName, customTpe) => + val concreteType = lookupTypeName(tpeName, concreteTypes)(identity) + val innerType = concreteType.getOrElse(Type.Name(tpeName)) + val tpe = t"${customTpe.getOrElse(t"_root_.scala.Predef.Map")}[_root_.scala.Predef.String, $innerType]" + Target.pure((tpe, Option.empty, ReifiedRawType.ofMap(fallbackRawType))) + } + fieldPattern: Tracker[Option[String]] = property.downField("pattern", _.getPattern) + collectionElementPattern: Option[Tracker[String]] = + property.downField("items", _.getItems).indexedDistribute.flatMap(_.downField("pattern", _.getPattern).indexedDistribute) + + pattern = collectionElementPattern.fold(fieldPattern.map(PropertyValidations))( + _.map(regex => PropertyValidations(Some(regex))) + ) + + presence <- ScalaGenerator().selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence")) + presenceType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage, "Presence")) + (finalDeclType, finalDefaultValue) = requirement match { + case PropertyRequirement.Required => tpe -> defaultValue + case PropertyRequirement.Optional | PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + t"$presenceType[$tpe]" -> defaultValue.map(t => q"$presence.Present($t)").orElse(Some(q"$presence.Absent")) + case _: PropertyRequirement.OptionalRequirement | _: PropertyRequirement.Configured => + t"Option[$tpe]" -> defaultValue.map(t => q"Option($t)").orElse(Some(q"None")) + case PropertyRequirement.OptionalNullable => + t"$presenceType[Option[$tpe]]" -> defaultValue.map(t => q"$presence.Present($t)") + } + term = param"${Term.Name(fieldName)}: ${finalDeclType}".copy(default = finalDefaultValue) + dep = classDep.filterNot(_.value == clsName) // Filter out our own class name + } yield ProtocolParameter[ScalaLanguage]( + term, + tpe, + RawParameterName(name), + dep, + rawType, + readOnlyKey, + emptyToNull, + dataRedaction, + requirement, + finalDefaultValue, + pattern + ) + } + + private def renderDTOClass( + className: String, + supportPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ) = { + val discriminatorParams = + parents.flatMap(parent => parent.discriminators.flatMap(discrim => parent.params.find(_.name.value == discrim.propertyName).map((discrim, _)))) + for { + discriminators <- discriminatorParams.traverse { case (discriminator, param) => + for { + discrimTpe <- Target.fromOption(param.term.decltpe, RuntimeFailure(s"Property ${param.name.value} has no type")) + discrimValue <- JacksonHelpers + .discriminatorExpression[ScalaLanguage]( + param.name.value, + discriminatorValue(discriminator, className), + param.rawType + )( + v => Target.pure[Term](q"""BigInt(${Lit.String(v)})"""), + v => Target.pure[Term](q"""BigDecimal(${Lit.String(v)})"""), + v => + param.term.decltpe.fold( + Target.raiseUserError[Term](s"No declared type for property '${param.name.value}' on class $className") + ) { + case tpe @ (_: Type.Name | _: Type.Select) => + ScalaGenerator().formatEnumName(v).map(ev => Term.Select(Term.Name(tpe.toString), Term.Name(ev))) + case tpe => Target.raiseError(RuntimeFailure(s"Assumed property ${param.name.value} was an enum, but can't handle $tpe")) } + )(ScalaGenerator()) + } yield (param.name.value, param.term.name.value, param.term.decltpe, discrimValue) + } + presenceSerType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "PresenceSerializer")) + presenceDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "PresenceDeserializer")) + optionNonNullDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonNullDeserializer")) + emptyIsNullDeserType <- ScalaGenerator().selectType(NonEmptyList.ofInitLast(supportPackage :+ "EmptyIsNullDeserializers", "EmptyIsNullDeserializer")) + emptyIsNullOptionDeserType <- ScalaGenerator().selectType( + NonEmptyList.ofInitLast(supportPackage :+ "EmptyIsNullDeserializers", "EmptyIsNullOptionDeserializer") + ) + optionNonMissingDeserType <- ScalaGenerator().selectType( + NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonMissingDeserializer") + ) + allTerms = selfParams ++ parents.flatMap(_.params) + renderedClass = { // TODO: This logic should be reflowed. The scope and rebindings is due to a refactor where + // code from another dependent class was just copied in here wholesale. + val discriminatorNames = parents.flatMap(_.discriminators).map(_.propertyName).toSet + val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + val terms = params.map(_.term) - implicit def guardrailValidateOption[A: GuardrailValidator]: GuardrailValidator[Option[A]] = new GuardrailValidator[Option[A]] { - override def validate(a: Option[A])(implicit validator: javax.validation.Validator): scala.util.Try[Option[A]] = - a.traverse(implicitly[GuardrailValidator[A]].validate) - } - implicit def guardrailValidateVector[A: GuardrailValidator]: GuardrailValidator[Vector[A]] = new GuardrailValidator[Vector[A]] { - override def validate(a: Vector[A])(implicit validator: javax.validation.Validator): scala.util.Try[Vector[A]] = - a.traverse(implicitly[GuardrailValidator[A]].validate) - } - implicit def guardrailValidateMap[A: GuardrailValidator]: GuardrailValidator[Map[String, A]] = new GuardrailValidator[Map[String, A]] { - override def validate(a: Map[String, A])(implicit validator: javax.validation.Validator): scala.util.Try[Map[String, A]] = - a.toVector.traverse({ case (k, v) => implicitly[GuardrailValidator[A]].validate(v).map((k, _)) }).map(_.toMap) - } - implicit val guardrailValidateBoolean: GuardrailValidator[Boolean] = noop - implicit val guardrailValidateInt: GuardrailValidator[Int] = noop - implicit val guardrailValidateLong: GuardrailValidator[Long] = noop - implicit val guardrailValidateBigInt: GuardrailValidator[BigInt] = noop - implicit val guardrailValidateFloat: GuardrailValidator[Float] = noop - implicit val guardrailValidateDouble: GuardrailValidator[Double] = noop - implicit val guardrailValidateBigDecimal: GuardrailValidator[BigDecimal] = noop - implicit val guardrailValidateString: GuardrailValidator[String] = instance - implicit val guardrailValidateBase64String: GuardrailValidator[Base64String] = instance - implicit val guardrailValidateInstant: GuardrailValidator[java.time.Instant] = noop - implicit val guardrailValidateLocalDate: GuardrailValidator[java.time.LocalDate] = noop - implicit val guardrailValidateLocalDateTime: GuardrailValidator[java.time.LocalDateTime] = noop - implicit val guardrailValidateLocalTime: GuardrailValidator[java.time.LocalTime] = noop - implicit val guardrailValidateOffsetDateTime: GuardrailValidator[java.time.OffsetDateTime] = noop - implicit val guardrailValidateZonedDateTime: GuardrailValidator[java.time.ZonedDateTime] = noop + val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) { + def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { + case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" + case _ => Lit.String("[redacted]") + } + + val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) + + List[Defn.Def]( + q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${className}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" + ) + } else { + List.empty[Defn.Def] + } + + val base = q"""$jsonIgnoreProperties case class ${Type.Name(className)}(..${terms.map(param => + allTerms + .find(_.term.name.value == param.name.value) + .fold(param) { term => + param.copy( + mods = paramAnnotations( + term, + presenceSerType, + presenceDeserType, + optionNonNullDeserType, + optionNonMissingDeserType, + emptyIsNullDeserType, + emptyIsNullOptionDeserType + ) ++ param.mods, + default = fixDefaultValue(term) + ) } - } - """ + )}) { + ..${discriminators.map { case (propertyName, fieldName, tpe, value) => + q""" + @com.fasterxml.jackson.annotation.JsonProperty(${Lit.String(propertyName)}) + val ${Pat.Var(Term.Name(fieldName))}: $tpe = $value + """ + }}; + ..$toStringMethod + }""" + + val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) { + parents.headOption + } else { + None + } + parentOpt + .fold(base)(parent => + base.copy( + templ = + Template(Nil, (parent.clsName +: parent.interfaces).map(n => Init(Type.Name(n), Name(""), Nil)), Self(Name(""), None), base.templ.stats, Nil) ) ) - ), - generateSupportDefinitions = () => - for { - generatedSupportDefinitions <- baseInterp.generateSupportDefinitions() - } yield { - val (presence, others) = generatedSupportDefinitions.partition(_.className.value == "Presence") - presence.headOption - .map(defn => - defn.copy( - definition = defn.definition.map { - case q"object Presence { ..$stmts }" => - q""" - object Presence { + } + } yield renderedClass + } + + private def encodeModel( + clsName: String, + dtoPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ) = Target.pure(None) + + private def decodeModel( + clsName: String, + dtoPackage: List[String], + supportPackage: List[String], + selfParams: List[ProtocolParameter[ScalaLanguage]], + parents: List[SuperClass[ScalaLanguage]] = Nil + ): Target[Option[Defn.Val]] = Target.pure(None) + + private def renderDTOStaticDefns( + className: String, + deps: List[scala.meta.Term.Name], + encoder: Option[scala.meta.Defn.Val], + decoder: Option[scala.meta.Defn.Val], + protocolParameters: List[ProtocolParameter[ScalaLanguage]] + ) = { + val extraImports: List[Import] = deps.map { term => + q"import ${term}._" + } + val classType = Type.Name(className) + val encoderInstance = q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$classType] = GuardrailEncoder.instance" + val decoderInstance = + q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$classType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$classType] {})" + val validatorInstance = q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$classType] = GuardrailValidator.instance" + Target.pure( + StaticDefns[ScalaLanguage]( + className = className, + extraImports = extraImports, + definitions = (encoder ++ decoder ++ List(encoderInstance, decoderInstance, validatorInstance)).toList + ) + ) + } + + private def extractArrayType(arr: core.ResolvedType[ScalaLanguage], concreteTypes: List[PropMeta[ScalaLanguage]]) = + for { + result <- arr match { + case core.Resolved(tpe, dep, default, _) => Target.pure(tpe) + case core.Deferred(tpeName) => + Target.fromOption(lookupTypeName(tpeName, concreteTypes)(identity), UserError(s"Unresolved reference ${tpeName}")) + case core.DeferredArray(tpeName, containerTpe) => + Target.fromOption( + lookupTypeName(tpeName, concreteTypes)(tpe => t"${containerTpe.getOrElse(t"_root_.scala.Vector")}[${tpe}]"), + UserError(s"Unresolved reference ${tpeName}") + ) + case core.DeferredMap(tpeName, customTpe) => + Target.fromOption( + lookupTypeName(tpeName, concreteTypes)(tpe => + t"_root_.scala.Vector[${customTpe.getOrElse(t"_root_.scala.Predef.Map")}[_root_.scala.Predef.String, ${tpe}]]" + ), + UserError(s"Unresolved reference ${tpeName}") + ) + } + } yield result + + private def extractConcreteTypes(definitions: Either[String, List[PropMeta[ScalaLanguage]]]) = + definitions.fold[Target[List[PropMeta[ScalaLanguage]]]](Target.raiseUserError _, Target.pure _) + + private def protocolImports() = + Target.pure( + List( + q"import cats.implicits._" + ) + ) + + override def staticProtocolImports(pkgName: List[String]): Target[List[Import]] = { + val implicitsRef: Term.Ref = (pkgName.map(Term.Name.apply _) ++ List(q"Implicits")).foldLeft[Term.Ref](q"_root_")(Term.Select.apply _) + Target.pure( + List( + q"import cats.implicits._", + q"import cats.data.EitherT", + q"import io.circe.refined._", + q"import eu.timepit.refined.api.Refined", + q"import eu.timepit.refined.auto._", + q"import $implicitsRef._" + ) + ) + } + + private def packageObjectImports() = + Target.pure(List.empty) + + override def generateSupportDefinitions() = { + val presenceTrait = + q"""sealed trait Presence[+T] extends _root_.scala.Product with _root_.scala.Serializable { + def fold[R](ifAbsent: => R, + ifPresent: T => R): R + def map[R](f: T => R): Presence[R] = fold(Presence.absent, a => Presence.present(f(a))) + + def toOption: Option[T] = fold[Option[T]](None, Some(_)) + } + """ + val presenceObject = + q""" + object Presence { import com.fasterxml.jackson.annotation.JsonInclude import com.fasterxml.jackson.core.{JsonGenerator, JsonParser, JsonToken} import com.fasterxml.jackson.databind._ @@ -585,22 +1513,37 @@ object JacksonProtocolGenerator { ) } - ..$stmts - } - """ - case other => other - } - ) - ) - .toList ++ others ++ List( - SupportDefinition[ScalaLanguage]( - q"EmptyIsNullDeserializers", - List( - q"import com.fasterxml.jackson.core.{JsonParser, JsonToken}", - q"import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonMappingException}" - ), - List( - q""" + def absent[R]: Presence[R] = Absent + def present[R](value: R): Presence[R] = Present(value) + case object Absent extends Presence[Nothing] { + def fold[R](ifAbsent: => R, + ifValue: Nothing => R): R = ifAbsent + } + final case class Present[+T](value: T) extends Presence[T] { + def fold[R](ifAbsent: => R, + ifPresent: T => R): R = ifPresent(value) + } + + def fromOption[T](value: Option[T]): Presence[T] = + value.fold[Presence[T]](Absent)(Present(_)) + + implicit object PresenceFunctor extends cats.Functor[Presence] { + def map[A, B](fa: Presence[A])(f: A => B): Presence[B] = fa.fold[Presence[B]](Presence.absent, a => Presence.present(f(a))) + } + } + """ + val presenceDefinition = SupportDefinition[ScalaLanguage](q"Presence", Nil, List(presenceTrait, presenceObject), insideDefinitions = false) + Target.pure( + List( + presenceDefinition, + SupportDefinition[ScalaLanguage]( + q"EmptyIsNullDeserializers", + List( + q"import com.fasterxml.jackson.core.{JsonParser, JsonToken}", + q"import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonMappingException}" + ), + List( + q""" @SuppressWarnings(Array("org.wartremover.warts.Throw")) object EmptyIsNullDeserializers { class EmptyIsNullDeserializer extends JsonDeserializer[String] { @@ -626,55 +1569,284 @@ object JacksonProtocolGenerator { } } """ - ), - insideDefinitions = false - ) - ) - }, - renderSealedTrait = (className, params, discriminator, parents, children) => - for { - renderedTrait <- baseInterp.renderSealedTrait(className, params, discriminator, parents, children) - discriminatorParam <- Target.pure(params.find(_.name.value == discriminator.propertyName)) - } yield { - val subTypes = children.map(child => - q"new com.fasterxml.jackson.annotation.JsonSubTypes.Type(name = ${Lit.String(discriminatorValue(discriminator, child))}, value = classOf[${Type.Name(child)}])" - ) + ), + insideDefinitions = false + ) + ) + ) + } + + private def packageObjectContents() = + Target.pure( + List.empty + ) + + private def implicitsObject() = Target.pure( + Some( + ( + q"JacksonImplicits", + q""" + object JacksonImplicits { + object constraints { + type NotNull = javax.validation.constraints.NotNull @scala.annotation.meta.field @scala.annotation.meta.param + } + + trait GuardrailEncoder[A] { + def encode(a: A)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper): com.fasterxml.jackson.databind.JsonNode = + mapper.valueToTree(a) + } + object GuardrailEncoder { + def instance[A]: GuardrailEncoder[A] = new GuardrailEncoder[A] {} + + implicit def guardrailEncodeOption[B: GuardrailEncoder]: GuardrailEncoder[Option[B]] = instance + implicit def guardrailEncodeVector[B: GuardrailEncoder]: GuardrailEncoder[Vector[B]] = instance + implicit def guardrailEncodeMap[B: GuardrailEncoder]: GuardrailEncoder[Map[String, B]] = instance + implicit val guardrailEncodeBoolean: GuardrailEncoder[Boolean] = instance + implicit val guardrailEncodeInt: GuardrailEncoder[Int] = instance + implicit val guardrailEncodeLong: GuardrailEncoder[Long] = instance + implicit val guardrailEncodeBigInt: GuardrailEncoder[BigInt] = instance + implicit val guardrailEncodeFloat: GuardrailEncoder[Float] = instance + implicit val guardrailEncodeDouble: GuardrailEncoder[Double] = instance + implicit val guardrailEncodeBigDecimal: GuardrailEncoder[BigDecimal] = instance + implicit val guardrailEncodeString: GuardrailEncoder[String] = instance + implicit val guardrailEncodeBase64String: GuardrailEncoder[Base64String] = instance + implicit val guardrailEncodeInstant: GuardrailEncoder[java.time.Instant] = instance + implicit val guardrailEncodeLocalDate: GuardrailEncoder[java.time.LocalDate] = instance + implicit val guardrailEncodeLocalDateTime: GuardrailEncoder[java.time.LocalDateTime] = instance + implicit val guardrailEncodeLocalTime: GuardrailEncoder[java.time.LocalTime] = instance + implicit val guardrailEncodeOffsetDateTime: GuardrailEncoder[java.time.OffsetDateTime] = instance + implicit val guardrailEncodeZonedDateTime: GuardrailEncoder[java.time.ZonedDateTime] = instance + } + + trait GuardrailDecoder[A] { + def tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] + def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[A]): scala.util.Try[A] = + scala.util.Try(this.tpe.fold(mapper.convertValue(jsonNode, _), mapper.convertValue(jsonNode, _))).flatMap(guardrailValidator.validate) + } + object GuardrailDecoder { + def instance[A](typeRef: com.fasterxml.jackson.core.`type`.TypeReference[A]): GuardrailDecoder[A] = new GuardrailDecoder[A] { + override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] = Left(typeRef) + } + def instance[A](cls: Class[A]): GuardrailDecoder[A] = new GuardrailDecoder[A] { + override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[A], Class[A]] = Right(cls) + } + + implicit def guardrailDecodeOption[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Option[B]] = new GuardrailDecoder[Option[B]] { + override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Option[B]], Class[Option[B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Option[B]] {}) + override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Option[B]]): scala.util.Try[Option[B]] = { + if (jsonNode.isNull) { + scala.util.Success(Option.empty[B]) + } else { + implicitly[GuardrailDecoder[B]].decode(jsonNode).map(Option.apply) + } + } + } + implicit def guardrailDecodeVector[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Vector[B]] = new GuardrailDecoder[Vector[B]] { + override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Vector[B]], Class[Vector[B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Vector[B]] {}) + override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Vector[B]]): scala.util.Try[Vector[B]] = { + jsonNode match { + case arr: com.fasterxml.jackson.databind.node.ArrayNode => + import cats.implicits._ + import _root_.scala.jdk.CollectionConverters._ + arr.iterator().asScala.toVector.traverse(implicitly[GuardrailDecoder[B]].decode) + case _ => + scala.util.Failure(new com.fasterxml.jackson.databind.JsonMappingException(null, s"Can't decode to vector; node of type $${jsonNode.getClass.getSimpleName} is not an array")) + } + } + } + implicit def guardrailDecodeMap[B: GuardrailDecoder: GuardrailValidator]: GuardrailDecoder[Map[String, B]] = new GuardrailDecoder[Map[String, B]] { + override val tpe: Either[com.fasterxml.jackson.core.`type`.TypeReference[Map[String, B]], Class[Map[String, B]]] = Left(new com.fasterxml.jackson.core.`type`.TypeReference[Map[String, B]] {}) + override def decode(jsonNode: com.fasterxml.jackson.databind.JsonNode)(implicit mapper: com.fasterxml.jackson.databind.ObjectMapper, validator: javax.validation.Validator, guardrailValidator: GuardrailValidator[Map[String, B]]): scala.util.Try[Map[String, B]] = { + jsonNode match { + case obj: com.fasterxml.jackson.databind.node.ObjectNode => + import cats.implicits._ + import _root_.scala.jdk.CollectionConverters._ + obj.fields().asScala.toVector.traverse(entry => implicitly[GuardrailDecoder[B]].decode(entry.getValue).map((entry.getKey, _))).map(_.toMap) + case _ => + scala.util.Failure(new com.fasterxml.jackson.databind.JsonMappingException(null, s"Can't decode to map; node of type $${jsonNode.getClass.getSimpleName} is not an object")) + } + } + } + implicit val guardrailDecodeBoolean: GuardrailDecoder[Boolean] = instance(classOf[Boolean]) + implicit val guardrailDecodeInt: GuardrailDecoder[Int] = instance(classOf[Int]) + implicit val guardrailDecodeLong: GuardrailDecoder[Long] = instance(classOf[Long]) + implicit val guardrailDecodeBigInt: GuardrailDecoder[BigInt] = instance(classOf[BigInt]) + implicit val guardrailDecodeFloat: GuardrailDecoder[Float] = instance(classOf[Float]) + implicit val guardrailDecodeDouble: GuardrailDecoder[Double] = instance(classOf[Double]) + implicit val guardrailDecodeBigDecimal: GuardrailDecoder[BigDecimal] = instance(classOf[BigDecimal]) + implicit val guardrailDecodeString: GuardrailDecoder[String] = instance(classOf[String]) + implicit val guardrailDecodeBase64String: GuardrailDecoder[Base64String] = instance(classOf[Base64String]) + implicit val guardrailDecodeInstant: GuardrailDecoder[java.time.Instant] = instance(classOf[java.time.Instant]) + implicit val guardrailDecodeLocalDate: GuardrailDecoder[java.time.LocalDate] = instance(classOf[java.time.LocalDate]) + implicit val guardrailDecodeLocalDateTime: GuardrailDecoder[java.time.LocalDateTime] = instance(classOf[java.time.LocalDateTime]) + implicit val guardrailDecodeLocalTime: GuardrailDecoder[java.time.LocalTime] = instance(classOf[java.time.LocalTime]) + implicit val guardrailDecodeOffsetDateTime: GuardrailDecoder[java.time.OffsetDateTime] = instance(classOf[java.time.OffsetDateTime]) + implicit val guardrailDecodeZonedDateTime: GuardrailDecoder[java.time.ZonedDateTime] = instance(classOf[java.time.ZonedDateTime]) + } + + trait GuardrailValidator[A] { + def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] + } + object GuardrailValidator { + def instance[A]: GuardrailValidator[A] = new GuardrailValidator[A] { + override def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] = { + import _root_.scala.jdk.CollectionConverters._ + scala.util.Try(validator.validate(a)).flatMap({ + case violations if violations.isEmpty => + scala.util.Success(a) + case violations => + scala.util.Failure(new javax.validation.ValidationException(s"Validation of $${a.getClass.getSimpleName} failed: $${violations.asScala.map(viol => s"$${viol.getPropertyPath}: $${viol.getMessage}").mkString("; ")}")) + }) + } + } + def noop[A]: GuardrailValidator[A] = new GuardrailValidator[A] { + override def validate(a: A)(implicit validator: javax.validation.Validator): scala.util.Try[A] = scala.util.Success(a) + } + + implicit def guardrailValidateOption[A: GuardrailValidator]: GuardrailValidator[Option[A]] = new GuardrailValidator[Option[A]] { + override def validate(a: Option[A])(implicit validator: javax.validation.Validator): scala.util.Try[Option[A]] = + a.traverse(implicitly[GuardrailValidator[A]].validate) + } + implicit def guardrailValidateVector[A: GuardrailValidator]: GuardrailValidator[Vector[A]] = new GuardrailValidator[Vector[A]] { + override def validate(a: Vector[A])(implicit validator: javax.validation.Validator): scala.util.Try[Vector[A]] = + a.traverse(implicitly[GuardrailValidator[A]].validate) + } + implicit def guardrailValidateMap[A: GuardrailValidator]: GuardrailValidator[Map[String, A]] = new GuardrailValidator[Map[String, A]] { + override def validate(a: Map[String, A])(implicit validator: javax.validation.Validator): scala.util.Try[Map[String, A]] = + a.toVector.traverse({ case (k, v) => implicitly[GuardrailValidator[A]].validate(v).map((k, _)) }).map(_.toMap) + } + implicit val guardrailValidateBoolean: GuardrailValidator[Boolean] = noop + implicit val guardrailValidateInt: GuardrailValidator[Int] = noop + implicit val guardrailValidateLong: GuardrailValidator[Long] = noop + implicit val guardrailValidateBigInt: GuardrailValidator[BigInt] = noop + implicit val guardrailValidateFloat: GuardrailValidator[Float] = noop + implicit val guardrailValidateDouble: GuardrailValidator[Double] = noop + implicit val guardrailValidateBigDecimal: GuardrailValidator[BigDecimal] = noop + implicit val guardrailValidateString: GuardrailValidator[String] = instance + implicit val guardrailValidateBase64String: GuardrailValidator[Base64String] = instance + implicit val guardrailValidateInstant: GuardrailValidator[java.time.Instant] = noop + implicit val guardrailValidateLocalDate: GuardrailValidator[java.time.LocalDate] = noop + implicit val guardrailValidateLocalDateTime: GuardrailValidator[java.time.LocalDateTime] = noop + implicit val guardrailValidateLocalTime: GuardrailValidator[java.time.LocalTime] = noop + implicit val guardrailValidateOffsetDateTime: GuardrailValidator[java.time.OffsetDateTime] = noop + implicit val guardrailValidateZonedDateTime: GuardrailValidator[java.time.ZonedDateTime] = noop + } + } + """ + ) + ) + ) + + private def extractSuperClass( + swagger: Tracker[ComposedSchema], + definitions: List[(String, Tracker[Schema[_]])] + ) = { + def allParents: Tracker[Schema[_]] => Target[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]] = + _.refine[Target[List[(String, Tracker[Schema[_]], List[Tracker[Schema[_]]])]]] { case x: ComposedSchema => x }( + _.downField("allOf", _.getAllOf()).indexedDistribute.filter(_.downField("$ref", _.get$ref()).unwrapTracker.nonEmpty) match { + case head :: tail => + definitions + .collectFirst { + case (clsName, e) if head.downField("$ref", _.get$ref()).exists(_.unwrapTracker.endsWith(s"/$clsName")) => + val thisParent = (clsName, e, tail) + allParents(e).map(otherParents => thisParent :: otherParents) + } + .getOrElse( + Target.raiseUserError(s"Reference ${head.downField("$ref", _.get$ref()).unwrapTracker} not found among definitions (${head.showHistory})") + ) + case _ => Target.pure(List.empty) + } + ).getOrElse(Target.pure(List.empty)) + allParents(swagger) + } + + private def renderADTStaticDefns( + className: String, + discriminator: Discriminator[ScalaLanguage], + encoder: Option[scala.meta.Defn.Val], + decoder: Option[scala.meta.Defn.Val] + ) = { + val classType = Type.Name(className) + Target.pure( + StaticDefns[ScalaLanguage]( + className = className, + extraImports = List.empty[Import], + definitions = List[Option[Defn]]( + Some(q"val discriminator: String = ${Lit.String(discriminator.propertyName)}"), + encoder, // TODO: encoder/decoder and the following three defns look _very_ suspicious. Evaluate whether they need to move into the emitter methods. + decoder, + Some(q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$classType] = GuardrailEncoder.instance"), + Some( + q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$classType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$classType] {})" + ), + Some(q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$classType] = GuardrailValidator.instance") + ).flatten + ) + ) + } + + private def decodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = + Target.pure(None) + + private def encodeADT(clsName: String, discriminator: Discriminator[ScalaLanguage], children: List[String] = Nil) = + Target.pure(None) + + private def renderSealedTrait( + className: String, + params: List[ProtocolParameter[ScalaLanguage]], + discriminator: Discriminator[ScalaLanguage], + parents: List[SuperClass[ScalaLanguage]] = Nil, + children: List[String] = Nil + ) = { + val parent = for { + testTerms <- + params + .map(_.term) + .filter(_.name.value != discriminator.propertyName) + .traverse { t => + for { + tpe <- Target.fromOption( + t.decltpe + .flatMap { + case tpe: Type => Some(tpe) + case x => None + }, + UserError(t.decltpe.fold("Nothing to map")(x => s"Unsure how to map ${x.structure}, please report this bug!")) + ) + } yield q"""def ${Term.Name(t.name.value)}: ${tpe}""" + } + } yield parents.headOption + .fold(q"""trait ${Type.Name(className)} {..${testTerms}}""")(parent => + q"""trait ${Type.Name(className)} extends ${init"${Type.Name(parent.clsName)}(...$Nil)"} { ..${testTerms} } """ + ) + for { + renderedTrait <- parent + discriminatorParam <- Target.pure(params.find(_.name.value == discriminator.propertyName)) + } yield { + val subTypes = children.map(child => + q"new com.fasterxml.jackson.annotation.JsonSubTypes.Type(name = ${Lit.String(discriminatorValue(discriminator, child))}, value = classOf[${Type.Name(child)}])" + ) - renderedTrait.copy( - mods = List( - mod""" + renderedTrait.copy( + mods = List( + mod""" @com.fasterxml.jackson.annotation.JsonTypeInfo( use = com.fasterxml.jackson.annotation.JsonTypeInfo.Id.NAME, include = com.fasterxml.jackson.annotation.JsonTypeInfo.As.PROPERTY, property = ${Lit.String(discriminator.propertyName)} ) """, - mod""" + mod""" @com.fasterxml.jackson.annotation.JsonSubTypes(Array( ..$subTypes )) """ - ) ++ renderedTrait.mods, - templ = renderedTrait.templ.copy( - stats = renderedTrait.templ.stats ++ discriminatorParam.map(param => - q"def ${Term.Name(param.term.name.value)}: ${param.term.decltpe.getOrElse(t"Any")}" - ) - ) - ) - }, - encodeADT = (_, _, _) => Target.pure(None), - decodeADT = (_, _, _) => Target.pure(None), - renderADTStaticDefns = (className, discriminator, encoder, decoder) => - for { - renderedADTStaticDefns <- baseInterp.renderADTStaticDefns(className, discriminator, encoder, decoder) - classType = Type.Name(className) - } yield renderedADTStaticDefns.copy( - definitions = renderedADTStaticDefns.definitions ++ List( - q"implicit val ${Pat.Var(Term.Name(s"encode${className}"))}: GuardrailEncoder[$classType] = GuardrailEncoder.instance", - q"implicit val ${Pat.Var(Term.Name(s"decode${className}"))}: GuardrailDecoder[$classType] = GuardrailDecoder.instance(new com.fasterxml.jackson.core.`type`.TypeReference[$classType] {})", - q"implicit val ${Pat.Var(Term.Name(s"validate${className}"))}: GuardrailValidator[$classType] = GuardrailValidator.instance" - ) + ) ++ renderedTrait.mods, + templ = renderedTrait.templ.copy( + stats = + renderedTrait.templ.stats ++ discriminatorParam.map(param => q"def ${Term.Name(param.term.name.value)}: ${param.term.decltpe.getOrElse(t"Any")}") ) - ) + ) + } } } diff --git a/modules/scala-support/src/test/scala/tests/circe/ArrayValidationTest.scala b/modules/scala-support/src/test/scala/tests/circe/ArrayValidationTest.scala index aedf94325b..d27c7f2234 100644 --- a/modules/scala-support/src/test/scala/tests/circe/ArrayValidationTest.scala +++ b/modules/scala-support/src/test/scala/tests/circe/ArrayValidationTest.scala @@ -4,7 +4,6 @@ import cats.data.NonEmptyList import dev.guardrail.Target import dev.guardrail.core.Tracker import dev.guardrail.generators.ProtocolDefinitions -import dev.guardrail.generators.ProtocolGenerator import dev.guardrail.generators.SwaggerGenerator import dev.guardrail.generators.scala.CirceRefinedModelGenerator import dev.guardrail.generators.scala.ScalaCollectionsGenerator @@ -93,8 +92,8 @@ class ArrayValidationTest extends AnyFreeSpec with Matchers with SwaggerSpecRunn | pattern: "pet" |""".stripMargin - val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = ProtocolGenerator - .fromSwagger[ScalaLanguage, Target]( + val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = circeProtocolGenerator + .fromSwagger( Tracker(swaggerFromString(collectionElementsWithPattern)), dtoPackage = Nil, supportPackage = NonEmptyList.one("foop"), @@ -118,8 +117,8 @@ class ArrayValidationTest extends AnyFreeSpec with Matchers with SwaggerSpecRunn "should generate size boundary constrains" in { - val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = ProtocolGenerator - .fromSwagger[ScalaLanguage, Target]( + val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = circeProtocolGenerator + .fromSwagger( Tracker(swaggerFromString(swagger)), dtoPackage = Nil, supportPackage = NonEmptyList.one("foop"), diff --git a/modules/scala-support/src/test/scala/tests/circe/BigObjectSpec.scala b/modules/scala-support/src/test/scala/tests/circe/BigObjectSpec.scala index df4495ea1f..fc4df29d10 100644 --- a/modules/scala-support/src/test/scala/tests/circe/BigObjectSpec.scala +++ b/modules/scala-support/src/test/scala/tests/circe/BigObjectSpec.scala @@ -11,7 +11,6 @@ import support.SwaggerSpecRunner import dev.guardrail.Target import dev.guardrail.core.Tracker import dev.guardrail.generators.ProtocolDefinitions -import dev.guardrail.generators.ProtocolGenerator import dev.guardrail.generators.SwaggerGenerator import dev.guardrail.generators.scala.CirceModelGenerator import dev.guardrail.generators.scala.ScalaCollectionsGenerator @@ -148,8 +147,8 @@ class BigObjectSpec extends AnyFunSuite with Matchers with SwaggerSpecRunner { implicit val circeProtocolGenerator = CirceProtocolGenerator(CirceModelGenerator.V012) implicit val scalaGenerator = ScalaGenerator() implicit val swaggerGenerator = SwaggerGenerator[ScalaLanguage]() - val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = ProtocolGenerator - .fromSwagger[ScalaLanguage, Target]( + val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = circeProtocolGenerator + .fromSwagger( Tracker(swaggerFromString(swagger)), dtoPackage = Nil, supportPackage = NonEmptyList.one("foop"), diff --git a/modules/scala-support/src/test/scala/tests/circe/ValidationTest.scala b/modules/scala-support/src/test/scala/tests/circe/ValidationTest.scala index fe49926577..19a910c33c 100644 --- a/modules/scala-support/src/test/scala/tests/circe/ValidationTest.scala +++ b/modules/scala-support/src/test/scala/tests/circe/ValidationTest.scala @@ -4,7 +4,6 @@ import cats.data.NonEmptyList import dev.guardrail.Target import dev.guardrail.core.Tracker import dev.guardrail.generators.ProtocolDefinitions -import dev.guardrail.generators.ProtocolGenerator import dev.guardrail.generators.SwaggerGenerator import dev.guardrail.generators.scala.CirceRefinedModelGenerator import dev.guardrail.generators.scala.ScalaCollectionsGenerator @@ -68,8 +67,8 @@ class ValidationTest extends AnyFreeSpec with Matchers with SwaggerSpecRunner { implicit val circeProtocolGenerator: ProtocolTerms[ScalaLanguage, Target] = CirceRefinedProtocolGenerator(CirceRefinedModelGenerator.V012) implicit val scalaGenerator = ScalaGenerator() implicit val swaggerGenerator = SwaggerGenerator[ScalaLanguage]() - val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = ProtocolGenerator - .fromSwagger[ScalaLanguage, Target]( + val ProtocolDefinitions(ClassDefinition(_, _, _, cls, staticDefns, _) :: Nil, _, _, _, _) = circeProtocolGenerator + .fromSwagger( Tracker(swaggerFromString(swagger)), dtoPackage = Nil, supportPackage = NonEmptyList.one("foop"),