From 984d705fe090228b6ccc200968a29ceebd7bffc0 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Sun, 21 Jan 2024 01:08:23 +0100 Subject: [PATCH] OpenAPI code gen with annotations fix (#2621) --- .../zio/http/gen/openapi/EndpointGen.scala | 92 ++++++------- .../scala/zio/http/gen/scala/CodeGen.scala | 15 +- .../resources/GeneratedUserNameArray.scala | 14 ++ .../test/scala/zio/http/gen/model/User.scala | 1 - .../zio/http/gen/model/UserNameArray.scala | 1 + .../http/gen/openapi/EndpointGenSpec.scala | 129 +++++++++++++----- .../zio/http/gen/scala/CodeGenSpec.scala | 17 +++ .../http/endpoint/openapi/OpenAPIGen.scala | 19 +-- 8 files changed, 190 insertions(+), 98 deletions(-) create mode 100644 zio-http-gen/src/test/resources/GeneratedUserNameArray.scala diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index 2ff57082c6..f9e3097c01 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -1,13 +1,14 @@ package zio.http.gen.openapi +import scala.annotation.tailrec + import zio.Chunk + import zio.http.Method import zio.http.endpoint.openapi.OpenAPI.ReferenceOr import zio.http.endpoint.openapi.{JsonSchema, OpenAPI} -import zio.http.gen.scala.Code import zio.http.gen.scala.Code.ScalaType - -import scala.annotation.tailrec +import zio.http.gen.scala.{Code, CodeGen} object EndpointGen { @@ -20,7 +21,6 @@ object EndpointGen { private val DataImports = List( Code.Import("zio.schema._"), - Code.Import("zio._"), ) private val RequestBodyRef = "#/components/requestBodies/(.*)".r @@ -156,17 +156,13 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => - s"Chunk[$ref]" - case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => - s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" - case JsonSchema.ArrayType(None) => - "Chunk[String]" - case schema if schema.isPrimitive => - schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case JsonSchema.Null => + Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => + ref + case schema if schema.isPrimitive || schema.isCollection => + CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType) + case schema => val code = schemaToCode(schema, openAPI, Inline.RequestBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -201,17 +197,11 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => - s"Chunk[$ref]" - case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => - s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" - case JsonSchema.ArrayType(None) => - "Chunk[String]" - case schema if schema.isPrimitive => - schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case schema if schema.isPrimitive || schema.isCollection => + CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType) + case schema => val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -250,17 +240,11 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => - s"Chunk[$ref]" - case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => - s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" - case JsonSchema.ArrayType(None) => - "Chunk[String]" - case schema if schema.isPrimitive => - schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case schema if schema.isPrimitive || schema.isCollection => + CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType) + case schema => val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -536,7 +520,7 @@ final case class EndpointGen() { Code.File( List("component", name.capitalize + ".scala"), pkgPath = List("component"), - imports = DataImports ++ + imports = dataImports(caseClasses.flatMap(_.fields)) ++ (if (noDiscriminator || caseNames.nonEmpty) List(Code.Import("zio.schema.annotation._")) else Nil), objects = Nil, caseClasses = Nil, @@ -583,7 +567,7 @@ final case class EndpointGen() { Code.File( List("component", name.capitalize + ".scala"), pkgPath = List("component"), - imports = DataImports, + imports = dataImports(fields), objects = Nil, caseClasses = List( Code.CaseClass( @@ -637,7 +621,7 @@ final case class EndpointGen() { Code.File( List("component", name.capitalize + ".scala"), pkgPath = List("component"), - imports = DataImports, + imports = dataImports(caseClasses.flatMap(_.fields)), objects = Nil, caseClasses = Nil, enums = List( @@ -666,23 +650,24 @@ final case class EndpointGen() { .asInstanceOf[Code.Field] if (required.contains(name)) field else field.copy(fieldType = field.fieldType.opt) }.toList - val nested = properties.collect { - case (name, schema) - if !schema.isInstanceOf[JsonSchema.RefSchema] - && !schema.isPrimitive - && !schema.isCollection => - schemaToCode(schema, openAPI, name.capitalize, Chunk.empty) - .getOrElse( - throw new Exception(s"Could not generate code for field $name of object $name"), - ) - } + val nested = + properties.map { case (name, schema) => name -> schema.withoutAnnotations }.collect { + case (name, schema) + if !schema.isInstanceOf[JsonSchema.RefSchema] + && !schema.isPrimitive + && !schema.isCollection => + schemaToCode(schema, openAPI, name.capitalize, Chunk.empty) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + } val nestedObjects = nested.flatMap(_.objects) val nestedCaseClasses = nested.flatMap(_.caseClasses) Some( Code.File( List("component", name.capitalize + ".scala"), pkgPath = List("component"), - imports = DataImports, + imports = dataImports(fields), objects = nestedObjects.toList, caseClasses = List( Code.CaseClass( @@ -792,4 +777,9 @@ final case class EndpointGen() { Some(Code.Field(name, ScalaType.JsonAST)) } } + + private def dataImports(fields: Iterable[Code.Field]) = { + if (fields.exists(_.fieldType.isInstanceOf[Code.Collection.Seq])) List(Code.Import("zio._")) + else Nil + } ++ DataImports } diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala index 4760289be4..96d51a150b 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala @@ -44,7 +44,8 @@ object CodeGen { throw new Exception("Files should be rendered separately") case Code.File(_, path, imports, objects, caseClasses, enums) => - s"package $basePackage.${path.mkString(".")}\n\n" + + s"package $basePackage${if (path.exists(_.nonEmpty)) path.mkString(if (basePackage.isEmpty) "" else ".", ".", "") + else ""}\n\n" + s"${imports.map(render(basePackage)).mkString("\n")}\n\n" + objects.map(render(basePackage)).mkString("\n") + caseClasses.map(render(basePackage)).mkString("\n") + @@ -58,7 +59,11 @@ object CodeGen { case Code.Object(name, schema, endpoints, objects, caseClasses, enums) => s"object $name {\n" + - (if (endpoints.nonEmpty) EndpointImports.map(render(basePackage)).mkString("", "\n", "\n") else "") + + (if (endpoints.nonEmpty) + (EndpointImports ++ (if (endpointWithChunk(endpoints)) List(Code.Import("zio.http._")) else Nil)) + .map(render(basePackage)) + .mkString("", "\n", "\n") + else "") + endpoints.map { case (k, v) => s"${render(basePackage)(k)}=${render(basePackage)(v)}" } .mkString("\n") + (if (schema) s"\n\n implicit val codec: Schema[$name] = DeriveSchema.gen[$name]" else "") + @@ -138,6 +143,12 @@ object CodeGen { throw new Exception(s"Unknown ScalaType: $scalaType") } + private def endpointWithChunk(endpoints: Map[Code.Field, Code.EndpointCode]) = + endpoints.exists { case (_, code) => + code.inCode.inType.contains("Chunk[") || + (code.outCodes ++ code.errorsCode).exists(_.outType.contains("Chunk[")) + } + def renderSegment(segment: Code.PathSegmentCode): String = segment match { case Code.PathSegmentCode(name, segmentType) => segmentType match { diff --git a/zio-http-gen/src/test/resources/GeneratedUserNameArray.scala b/zio-http-gen/src/test/resources/GeneratedUserNameArray.scala new file mode 100644 index 0000000000..ba845e63ed --- /dev/null +++ b/zio-http-gen/src/test/resources/GeneratedUserNameArray.scala @@ -0,0 +1,14 @@ +package test.component + +import zio._ +import zio.schema._ + +case class UserNameArray( + id: Int, + name: Chunk[String], +) +object UserNameArray { + + implicit val codec: Schema[UserNameArray] = DeriveSchema.gen[UserNameArray] + +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala index 782fe12a83..2fab37e36c 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala @@ -1,7 +1,6 @@ package zio.http.gen.model import zio.schema._ -import zio.schema.annotation._ case class User(id: Int, name: String) object User { diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala index 4a17fe8abe..3e5bb6f0f6 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala @@ -1,6 +1,7 @@ package zio.http.gen.model import zio.Chunk + import zio.schema._ case class UserNameArray(id: Int, name: Chunk[String]) diff --git a/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala index 42fa303b52..838ac0c8ab 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala @@ -1,6 +1,10 @@ package zio.http.gen.openapi +import java.nio.file._ + import zio._ +import zio.test._ + import zio.http._ import zio.http.codec.HeaderCodec import zio.http.codec.HttpCodec.{query, queryInt} @@ -9,9 +13,6 @@ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle.Inline import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} import zio.http.gen.model._ import zio.http.gen.scala.Code -import zio.test._ - -import java.nio.file._ object EndpointGenSpec extends ZIOSpecDefault { override def spec: Spec[TestEnvironment with Scope, Any] = @@ -712,7 +713,10 @@ object EndpointGenSpec extends ZIOSpecDefault { val expected = Code.File( List("component", "Payment.scala"), pkgPath = List("component"), - imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + imports = List( + Code.Import(name = "zio.schema._"), + Code.Import(name = "zio.schema.annotation._"), + ), objects = List.empty, caseClasses = List.empty, enums = List( @@ -749,7 +753,10 @@ object EndpointGenSpec extends ZIOSpecDefault { val expected = Code.File( List("component", "PaymentNamedDiscriminator.scala"), pkgPath = List("component"), - imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + imports = List( + Code.Import(name = "zio.schema._"), + Code.Import(name = "zio.schema.annotation._"), + ), objects = List.empty, caseClasses = List.empty, enums = List( @@ -788,7 +795,10 @@ object EndpointGenSpec extends ZIOSpecDefault { val expected = Code.File( List("component", "PaymentNoDiscriminator.scala"), pkgPath = List("component"), - imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + imports = List( + Code.Import(name = "zio.schema._"), + Code.Import(name = "zio.schema.annotation._"), + ), objects = List.empty, caseClasses = List.empty, enums = List( @@ -992,12 +1002,9 @@ object EndpointGenSpec extends ZIOSpecDefault { }, test("generates case class with seq field for request") { val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] - val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint) + val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint) val scala = EndpointGen.fromOpenAPI(openAPI) - val fields = List( - Code.Field("id", Code.Primitive.ScalaInt), - Code.Field("name", Code.Primitive.ScalaString), - ) + println(openAPI.toJsonPretty) val expected = Code.File( List("api", "v1", "Users.scala"), pkgPath = List("api", "v1"), @@ -1014,36 +1021,86 @@ object EndpointGenSpec extends ZIOSpecDefault { ), queryParamsCode = Set.empty, headersCode = Code.HeadersCode.empty, - inCode = Code.InCode("POST.RequestBody"), - outCodes = List(Code.OutCode.json("POST.ResponseBody", Status.Ok)), + inCode = Code.InCode("UserNameArray"), + outCodes = List(Code.OutCode.json("User", Status.Ok)), errorsCode = Nil, ), ), - objects = List( - Code.Object( - "POST", - schema = false, - endpoints = Map.empty, - objects = Nil, - caseClasses = List( - Code - .CaseClass( - "RequestBody", - fields = List( - Code.Field("id", Code.Primitive.ScalaInt), - Code.Field("name", Code.Primitive.ScalaString.seq), - ), - companionObject = Some(Code.Object.schemaCompanion("RequestBody")), - ), - Code.CaseClass( - "ResponseBody", - fields = fields, - companionObject = Some(Code.Object.schemaCompanion("ResponseBody")), - ), - ), - enums = Nil, + objects = Nil, + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("generates code from openapi with examples") { + val openapiJson = """{ + | "openapi": "3.0.3", + | "info": { + | "title": "Example", + | "description": "Example API documentation", + | "version": "1.0.0" + | }, + | "paths": { + | "/foo": { + | "post": { + | "requestBody": { + | "content": { + | "application/json": { + | "schema": { + | "$ref": "#/components/schemas/Bar" + | } + | } + | } + | }, + | "responses": { + | "200": { + | "description": "Success" + | } + | } + | } + | } + | }, + | "components": { + | "schemas": { + | "Bar": { + | "type": "object", + | "properties": { + | "stringField": { + | "type": "string", + | "example": "abc" + | } + | } + | } + | } + | } + |} + """.stripMargin + val openAPI = OpenAPI.fromJson(openapiJson).toOption.get + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("", "Foo.scala"), + pkgPath = List(""), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Foo", + schema = false, + endpoints = Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = List(Code.PathSegmentCode("foo"))), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Bar"), + outCodes = List(Code.OutCode.json("Unit", Status.Ok)), + errorsCode = Nil, ), ), + objects = Nil, caseClasses = Nil, enums = Nil, ), diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index 129264e257..a336cbbb2b 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -12,6 +12,7 @@ import zio.test._ import zio.http._ import zio.http.codec._ import zio.http.endpoint.Endpoint +import zio.http.endpoint.openapi.JsonSchema.SchemaStyle.Inline import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} import zio.http.gen.model._ import zio.http.gen.openapi.EndpointGen @@ -149,6 +150,7 @@ object CodeGenSpec extends ZIOSpecDefault { val code = EndpointGen.fromOpenAPI(openAPI) val tempDir = Files.createTempDirectory("codegen") + println(tempDir) CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) fileShouldBe( @@ -206,6 +208,7 @@ object CodeGenSpec extends ZIOSpecDefault { val code = EndpointGen.fromOpenAPI(openAPI) val tempDir = Files.createTempDirectory("codegen") + println(tempDir) CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) fileShouldBe( @@ -214,5 +217,19 @@ object CodeGenSpec extends ZIOSpecDefault { "/GeneratedValues.scala", ) }, + test("Endpoint with array field in input") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] + val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe( + tempDir, + "test/component/UserNameArray.scala", + "/GeneratedUserNameArray.scala", + ) + }, ) @@ java11OrNewer @@ flaky // Downloading scalafmt on CI is flaky } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 407f875de1..af399517b5 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -1,20 +1,23 @@ package zio.http.endpoint.openapi +import java.util.UUID + +import scala.annotation.tailrec +import scala.collection.{immutable, mutable} + import zio.Chunk -import zio.http._ -import zio.http.codec.HttpCodec.Metadata -import zio.http.codec._ -import zio.http.endpoint._ -import zio.http.endpoint.openapi.JsonSchema.SchemaStyle import zio.json.EncoderOps import zio.json.ast.Json + import zio.schema.Schema.Record import zio.schema.codec.JsonCodec import zio.schema.{Schema, TypeId} -import java.util.UUID -import scala.annotation.tailrec -import scala.collection.{immutable, mutable} +import zio.http._ +import zio.http.codec.HttpCodec.Metadata +import zio.http.codec._ +import zio.http.endpoint._ +import zio.http.endpoint.openapi.JsonSchema.SchemaStyle object OpenAPIGen { private val PathWildcard = "pathWildcard"