Skip to content

Commit

Permalink
OpenAPI code gen with annotations fix (#2621)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 21, 2024
1 parent 2072383 commit 984d705
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 98 deletions.
92 changes: 41 additions & 51 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package zio.http.gen.openapi

import scala.annotation.tailrec

import zio.Chunk

import zio.http.Method
import zio.http.endpoint.openapi.OpenAPI.ReferenceOr
import zio.http.endpoint.openapi.{JsonSchema, OpenAPI}
import zio.http.gen.scala.Code
import zio.http.gen.scala.Code.ScalaType

import scala.annotation.tailrec
import zio.http.gen.scala.{Code, CodeGen}

object EndpointGen {

Expand All @@ -20,7 +21,6 @@ object EndpointGen {
private val DataImports =
List(
Code.Import("zio.schema._"),
Code.Import("zio._"),
)

private val RequestBodyRef = "#/components/requestBodies/(.*)".r
Expand Down Expand Up @@ -156,17 +156,13 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case JsonSchema.Null =>
Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) =>
ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
val code = schemaToCode(schema, openAPI, Inline.RequestBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -201,17 +197,11 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -250,17 +240,11 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -536,7 +520,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = DataImports ++
imports = dataImports(caseClasses.flatMap(_.fields)) ++
(if (noDiscriminator || caseNames.nonEmpty) List(Code.Import("zio.schema.annotation._")) else Nil),
objects = Nil,
caseClasses = Nil,
Expand Down Expand Up @@ -583,7 +567,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = DataImports,
imports = dataImports(fields),
objects = Nil,
caseClasses = List(
Code.CaseClass(
Expand Down Expand Up @@ -637,7 +621,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = DataImports,
imports = dataImports(caseClasses.flatMap(_.fields)),
objects = Nil,
caseClasses = Nil,
enums = List(
Expand Down Expand Up @@ -666,23 +650,24 @@ final case class EndpointGen() {
.asInstanceOf[Code.Field]
if (required.contains(name)) field else field.copy(fieldType = field.fieldType.opt)
}.toList
val nested = properties.collect {
case (name, schema)
if !schema.isInstanceOf[JsonSchema.RefSchema]
&& !schema.isPrimitive
&& !schema.isCollection =>
schemaToCode(schema, openAPI, name.capitalize, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for field $name of object $name"),
)
}
val nested =
properties.map { case (name, schema) => name -> schema.withoutAnnotations }.collect {
case (name, schema)
if !schema.isInstanceOf[JsonSchema.RefSchema]
&& !schema.isPrimitive
&& !schema.isCollection =>
schemaToCode(schema, openAPI, name.capitalize, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for field $name of object $name"),
)
}
val nestedObjects = nested.flatMap(_.objects)
val nestedCaseClasses = nested.flatMap(_.caseClasses)
Some(
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = DataImports,
imports = dataImports(fields),
objects = nestedObjects.toList,
caseClasses = List(
Code.CaseClass(
Expand Down Expand Up @@ -792,4 +777,9 @@ final case class EndpointGen() {
Some(Code.Field(name, ScalaType.JsonAST))
}
}

private def dataImports(fields: Iterable[Code.Field]) = {
if (fields.exists(_.fieldType.isInstanceOf[Code.Collection.Seq])) List(Code.Import("zio._"))
else Nil
} ++ DataImports
}
15 changes: 13 additions & 2 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ object CodeGen {
throw new Exception("Files should be rendered separately")

case Code.File(_, path, imports, objects, caseClasses, enums) =>
s"package $basePackage.${path.mkString(".")}\n\n" +
s"package $basePackage${if (path.exists(_.nonEmpty)) path.mkString(if (basePackage.isEmpty) "" else ".", ".", "")
else ""}\n\n" +
s"${imports.map(render(basePackage)).mkString("\n")}\n\n" +
objects.map(render(basePackage)).mkString("\n") +
caseClasses.map(render(basePackage)).mkString("\n") +
Expand All @@ -58,7 +59,11 @@ object CodeGen {

case Code.Object(name, schema, endpoints, objects, caseClasses, enums) =>
s"object $name {\n" +
(if (endpoints.nonEmpty) EndpointImports.map(render(basePackage)).mkString("", "\n", "\n") else "") +
(if (endpoints.nonEmpty)
(EndpointImports ++ (if (endpointWithChunk(endpoints)) List(Code.Import("zio.http._")) else Nil))
.map(render(basePackage))
.mkString("", "\n", "\n")
else "") +
endpoints.map { case (k, v) => s"${render(basePackage)(k)}=${render(basePackage)(v)}" }
.mkString("\n") +
(if (schema) s"\n\n implicit val codec: Schema[$name] = DeriveSchema.gen[$name]" else "") +
Expand Down Expand Up @@ -138,6 +143,12 @@ object CodeGen {
throw new Exception(s"Unknown ScalaType: $scalaType")
}

private def endpointWithChunk(endpoints: Map[Code.Field, Code.EndpointCode]) =
endpoints.exists { case (_, code) =>
code.inCode.inType.contains("Chunk[") ||
(code.outCodes ++ code.errorsCode).exists(_.outType.contains("Chunk["))
}

def renderSegment(segment: Code.PathSegmentCode): String = segment match {
case Code.PathSegmentCode(name, segmentType) =>
segmentType match {
Expand Down
14 changes: 14 additions & 0 deletions zio-http-gen/src/test/resources/GeneratedUserNameArray.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package test.component

import zio._
import zio.schema._

case class UserNameArray(
id: Int,
name: Chunk[String],
)
object UserNameArray {

implicit val codec: Schema[UserNameArray] = DeriveSchema.gen[UserNameArray]

}
1 change: 0 additions & 1 deletion zio-http-gen/src/test/scala/zio/http/gen/model/User.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package zio.http.gen.model

import zio.schema._
import zio.schema.annotation._

case class User(id: Int, name: String)
object User {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zio.http.gen.model

import zio.Chunk

import zio.schema._

case class UserNameArray(id: Int, name: Chunk[String])
Expand Down
Loading

0 comments on commit 984d705

Please sign in to comment.