From 401b43c13de8f68662fd4f1edd20677310caf3fa Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 9 Oct 2023 19:32:08 -0700 Subject: [PATCH] Moving extractSecuritySchemes into Common --- .../src/main/scala/dev/guardrail/Common.scala | 34 ++++++++++++++- .../scala/dev/guardrail/SwaggerUtil.scala | 41 ------------------- 2 files changed, 32 insertions(+), 43 deletions(-) delete mode 100644 modules/core/src/main/scala/dev/guardrail/SwaggerUtil.scala diff --git a/modules/core/src/main/scala/dev/guardrail/Common.scala b/modules/core/src/main/scala/dev/guardrail/Common.scala index f02ac81cb9..fba962e7e3 100644 --- a/modules/core/src/main/scala/dev/guardrail/Common.scala +++ b/modules/core/src/main/scala/dev/guardrail/Common.scala @@ -1,6 +1,7 @@ package dev.guardrail import _root_.io.swagger.v3.oas.models.OpenAPI +import io.swagger.v3.oas.models.security.{ SecurityScheme => SwSecurityScheme } import cats.data.NonEmptyList import cats.syntax.all._ import cats.Id @@ -8,6 +9,7 @@ import java.nio.file.Path import java.net.URI import dev.guardrail.core.{ SupportDefinition, Tracker } +import dev.guardrail.core.extract.CustomTypeName import dev.guardrail.generators.{ Clients, Servers } import dev.guardrail.generators.ProtocolDefinitions import dev.guardrail.languages.LA @@ -15,12 +17,40 @@ import dev.guardrail.terms.client.ClientTerms import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.RandomType import dev.guardrail.terms.server.ServerTerms -import dev.guardrail.terms.{ CollectionsLibTerms, CoreTerms, LanguageTerms, ProtocolTerms, SecurityRequirements, SwaggerTerms } +import dev.guardrail.terms.{ CollectionsLibTerms, CoreTerms, LanguageTerms, ProtocolTerms, SecurityRequirements, SecurityScheme, SwaggerTerms } object Common { val resolveFile: Path => List[String] => Path = root => _.foldLeft(root)(_.resolve(_)) val resolveFileNel: Path => NonEmptyList[String] => Path = root => _.foldLeft(root)(_.resolve(_)) + private[this] def extractSecuritySchemes[L <: LA, F[_]]( + spec: OpenAPI, + prefixes: List[String] + )(implicit Sw: SwaggerTerms[L, F], Sc: LanguageTerms[L, F]): F[Map[String, SecurityScheme[L]]] = { + import Sw._ + import Sc._ + + Tracker(spec) + .downField("components", _.getComponents) + .flatDownField("securitySchemes", _.getSecuritySchemes) + .indexedDistribute + .value + .flatTraverse { case (schemeName, scheme) => + val typeName = CustomTypeName(scheme, prefixes) + for { + tpe <- typeName.fold(Option.empty[L#Type].pure[F])(x => parseType(Tracker.cloneHistory(scheme, x))) + parsedScheme <- scheme.downField("type", _.getType).unwrapTracker.traverse { + case SwSecurityScheme.Type.APIKEY => extractApiKeySecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] + case SwSecurityScheme.Type.HTTP => extractHttpSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] + case SwSecurityScheme.Type.OPENIDCONNECT => extractOpenIdConnectSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] + case SwSecurityScheme.Type.OAUTH2 => extractOAuth2SecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] + case SwSecurityScheme.Type.MUTUALTLS => extractMutualTLSSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] + } + } yield parsedScheme.toList.map(scheme => schemeName -> scheme) + } + .map(_.toMap) + } + def prepareDefinitions[L <: LA, F[_]]( kind: CodegenTarget, context: Context, @@ -82,7 +112,7 @@ object Common { requestBodies <- extractCommonRequestBodies(components) routes <- extractOperations(paths, requestBodies, globalSecurityRequirements) prefixes <- Cl.vendorPrefixes() - securitySchemes <- SwaggerUtil.extractSecuritySchemes(spec.unwrapTracker, prefixes) + securitySchemes <- extractSecuritySchemes(spec.unwrapTracker, prefixes) classNamedRoutes <- routes.traverse(route => getClassName(route.operation, prefixes, context.tagsBehaviour).map(_ -> route)) groupedRoutes = classNamedRoutes .groupMap(_._1)(_._2) diff --git a/modules/core/src/main/scala/dev/guardrail/SwaggerUtil.scala b/modules/core/src/main/scala/dev/guardrail/SwaggerUtil.scala deleted file mode 100644 index 81bf1b9c02..0000000000 --- a/modules/core/src/main/scala/dev/guardrail/SwaggerUtil.scala +++ /dev/null @@ -1,41 +0,0 @@ -package dev.guardrail - -import io.swagger.v3.oas.models._ -import io.swagger.v3.oas.models.security.{ SecurityScheme => SwSecurityScheme } -import cats.syntax.all._ -import dev.guardrail.core.Tracker -import dev.guardrail.core.implicits._ -import dev.guardrail.terms.{ LanguageTerms, SecurityScheme, SwaggerTerms } -import dev.guardrail.core.extract.CustomTypeName -import dev.guardrail.core.extract.VendorExtension.VendorExtensible._ -import dev.guardrail.languages.LA - -object SwaggerUtil { - def extractSecuritySchemes[L <: LA, F[_]]( - spec: OpenAPI, - prefixes: List[String] - )(implicit Sw: SwaggerTerms[L, F], Sc: LanguageTerms[L, F]): F[Map[String, SecurityScheme[L]]] = { - import Sw._ - import Sc._ - - Tracker(spec) - .downField("components", _.getComponents) - .flatDownField("securitySchemes", _.getSecuritySchemes) - .indexedDistribute - .value - .flatTraverse { case (schemeName, scheme) => - val typeName = CustomTypeName(scheme, prefixes) - for { - tpe <- typeName.fold(Option.empty[L#Type].pure[F])(x => parseType(Tracker.cloneHistory(scheme, x))) - parsedScheme <- scheme.downField("type", _.getType).unwrapTracker.traverse { - case SwSecurityScheme.Type.APIKEY => extractApiKeySecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] - case SwSecurityScheme.Type.HTTP => extractHttpSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] - case SwSecurityScheme.Type.OPENIDCONNECT => extractOpenIdConnectSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] - case SwSecurityScheme.Type.OAUTH2 => extractOAuth2SecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] - case SwSecurityScheme.Type.MUTUALTLS => extractMutualTLSSecurityScheme(schemeName, scheme, tpe).widen[SecurityScheme[L]] - } - } yield parsedScheme.toList.map(scheme => schemeName -> scheme) - } - .map(_.toMap) - } -}