diff --git a/core/src/main/scala-3/sttp/tapir/internal/SNameMacros.scala b/core/src/main/scala-3/sttp/tapir/internal/SNameMacros.scala index b7b17f15ea..f9cd16e229 100644 --- a/core/src/main/scala-3/sttp/tapir/internal/SNameMacros.scala +++ b/core/src/main/scala-3/sttp/tapir/internal/SNameMacros.scala @@ -7,7 +7,12 @@ import sttp.tapir.SchemaType.* object SNameMacros { inline def typeFullName[T] = ${ typeFullNameImpl[T] } - private def typeFullNameImpl[T: Type](using q: Quotes) = { + private def typeFullNameImpl[T: Type](using q: Quotes): Expr[String] = { + import q.reflect.* + val tpe = TypeRepr.of[T] + Expr(typeFullNameFromTpe(tpe)) + } + def typeFullNameFromTpe(using q: Quotes)(tpe: q.reflect.TypeRepr): String = { import q.reflect.* def normalizedName(s: Symbol): String = if s.flags.is(Flags.Module) then s.name.stripSuffix("$") else s.name @@ -18,9 +23,7 @@ object SNameMacros { else if sym == defn.RootClass then List.empty else nameChain(sym.owner) :+ normalizedName(sym) - val tpe = TypeRepr.of[T] - - Expr(nameChain(tpe.typeSymbol).mkString(".")) + nameChain(tpe.typeSymbol).mkString(".") } def extractTypeArguments(using q: Quotes)(tpe: q.reflect.TypeRepr): List[String] = { diff --git a/core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala b/core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala index 83312ee972..90c760a2f9 100644 --- a/core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala +++ b/core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala @@ -1,9 +1,10 @@ package sttp.tapir.macros -import sttp.tapir.{Validator, Schema, SchemaAnnotations, SchemaType} +import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator} import sttp.tapir.SchemaType.SchemaWithValue import sttp.tapir.generic.Configuration -import magnolia1._ +import magnolia1.* +import sttp.tapir.Schema.SName import sttp.tapir.generic.auto.SchemaMagnoliaDerivation import scala.quoted.* @@ -131,6 +132,11 @@ trait SchemaCompanionMacros extends SchemaMagnoliaDerivation { */ inline def oneOfWrapped[E](implicit conf: Configuration): Schema[E] = ${ SchemaCompanionMacros.generateOneOfWrapped[E]('conf) } + /** Derives the schema for a union type `E`. Schemas for all components of the union type must be available in the implicit scope at the + * point of invocation. + */ + inline def derivedUnion[E]: Schema[E] = ${ SchemaCompanionMacros.derivedUnion[E] } + /** Create a schema for an [[Enumeration]], where the validator is created using the enumeration's values. The low-level representation of * the enum is a `String`, and the enum values in the documentation will be encoded using `.toString`. */ @@ -337,4 +343,115 @@ private[tapir] object SchemaCompanionMacros { report.errorAndAbort(s"Can only derive Schema for values owned by scala.Enumeration") } } + + def derivedUnion[T: Type](using q: Quotes): Expr[Schema[T]] = { + import q.reflect.* + + val tpe = TypeRepr.of[T] + def typeParams = SNameMacros.extractTypeArguments(tpe) + + // first, finding all of the components of the union type + def findOrTypes(t: TypeRepr, failIfNotOrType: Boolean = true): List[TypeRepr] = + t.dealias match { + // only failing if the top-level type is not an OrType + case OrType(l, r) => findOrTypes(l, failIfNotOrType = false) ++ findOrTypes(r, failIfNotOrType = false) + case _ if failIfNotOrType => + report.errorAndAbort(s"Can only derive Schemas for union types, got: ${tpe.show}") + case _ => List(t) + } + + val orTypes = findOrTypes(tpe) + + // then, looking up schemas for each of the components + val schemas: List[Expr[Schema[_]]] = orTypes.map { orType => + orType.asType match { + case '[f] => + Expr.summon[Schema[f]] match { + case Some(subSchema) => subSchema + case None => + val typeName = TypeRepr.of[f].show + report.errorAndAbort(s"Cannot summon schema for `$typeName`. Make sure schema derivation is properly configured.") + } + } + } + + // then, constructing the name of the schema; if the type is not named, we generate a name by hand by concatenating + // names of the components + val orTypesNames = Expr.ofList(orTypes.map { orType => + orType.asType match { + case '[f] => + val typeParams = SNameMacros.extractTypeArguments(orType) + '{ _root_.sttp.tapir.Schema.SName(SNameMacros.typeFullName[f], ${ Expr(typeParams) }) } + } + }) + + val baseName = SNameMacros.typeFullNameFromTpe(tpe) + val snameExpr = if baseName.isEmpty then '{ SName(${ orTypesNames }.map(_.show).mkString("_or_")) } + else '{ SName(${ Expr(baseName) }, ${ Expr(typeParams) }) } + + // then, generating the method which maps a specific value to a schema, trying to match to one of the components + val typesAndSchemas = orTypes.zip(schemas) // both lists have the same length + def subtypeSchema(e: Expr[T]) = { + val eIdent = e.asTerm match { + case Inlined(_, _, ei: Ident) => ei + case ei: Ident => ei + } + + // if an or-type component that is generic appears more than once, we won't be able to perform a runtime check, + // to get the correct schema; in such case, instead of generating a `case ...`, we add a (single!) + // `case _ => None` to the match + val genericTypesThatAppearMoreThanOnce = { + var seen = Set[String]() + var result = Set[String]() + + orTypes.foreach { orType => + orType.classSymbol match { + case Some(sym) if orType.typeArgs.nonEmpty => // is generic + if seen.contains(sym.fullName) then result = result + sym.fullName + else seen = seen + sym.fullName + case _ => // skip + } + } + + result + } + + val baseCases = typesAndSchemas.flatMap { (orType, orTypeSchema) => + def caseThen = Block(Nil, '{ Some(SchemaWithValue($orTypeSchema.asInstanceOf[Schema[Any]], $e)) }.asTerm) + + orType.classSymbol match + case None => Some(CaseDef(Ident(orType.termSymbol.termRef), None, caseThen)) + case Some(sym) if orType.typeArgs.nonEmpty => + if genericTypesThatAppearMoreThanOnce.contains(sym.fullName) then None + else + val wildcardTypeParameters: List[Tree] = + List.fill(orType.typeArgs.length)(TypeBoundsTree(TypeTree.of[Nothing], TypeTree.of[Any])) + Some(CaseDef(Typed(Wildcard(), Applied(TypeIdent(sym), wildcardTypeParameters)), None, caseThen)) + case Some(sym) => Some(CaseDef(Typed(Wildcard(), TypeIdent(sym)), None, caseThen)) + } + val cases = + if genericTypesThatAppearMoreThanOnce.nonEmpty + then baseCases :+ CaseDef(Wildcard(), None, Block(Nil, '{ None }.asTerm)) + else baseCases + val t = Match(eIdent, cases) + + t.asExprOf[Option[SchemaWithValue[_]]] + } + + // finally, generating code which creates the SCoproduct + '{ + import _root_.sttp.tapir.Schema + import _root_.sttp.tapir.Schema._ + import _root_.sttp.tapir.SchemaType._ + import _root_.scala.collection.immutable.{List, Map} + + val childSchemas = List(${ Varargs(schemas) }: _*) + val sname = $snameExpr + + Schema( + schemaType = SCoproduct[T](childSchemas, None) { e => ${ subtypeSchema('{ e }) } }, + name = Some(sname) + ) + } + } } diff --git a/core/src/main/scala/sttp/tapir/Schema.scala b/core/src/main/scala/sttp/tapir/Schema.scala index ebf04eae1d..b00b02bc26 100644 --- a/core/src/main/scala/sttp/tapir/Schema.scala +++ b/core/src/main/scala/sttp/tapir/Schema.scala @@ -333,7 +333,7 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros { } case class SName(fullName: String, typeParameterShortNames: List[String] = Nil) { - def show: String = fullName + typeParameterShortNames.mkString("[", ",", "]") + def show: String = fullName + (if (typeParameterShortNames.isEmpty) "" else typeParameterShortNames.mkString("[", ",", "]")) } object SName { val Unit: SName = SName(fullName = "Unit") diff --git a/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala b/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala index 81db59105c..150aeab60a 100644 --- a/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala +++ b/core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala @@ -2,13 +2,12 @@ package sttp.tapir import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers - -import scala.concurrent.duration.Duration +import sttp.tapir.internal.SNameMacros class SchemaMacroScala3Test extends AnyFlatSpec with Matchers: import SchemaMacroScala3Test._ - it should "validate enum" in { + it should "derive a one-of-wrapped schema for enums" in { given s: Schema[Fruit] = Schema.oneOfWrapped s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } @@ -18,6 +17,79 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers: coproduct.subtypeSchema(Fruit.Apple).map(_.value) shouldBe Some(Fruit.Apple) } + it should "derive schema for union types" in { + // when + val s: Schema[String | Int] = Schema.derivedUnion + + // then + s.name.map(_.show) shouldBe Some("java.lang.String_or_scala.Int") + + s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } + val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[String | Int]] + coproduct.subtypes should have size 2 + coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString()) + coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger()) + } + + it should "derive schema for a named union type" in { + // when + val s: Schema[StringOrInt] = Schema.derivedUnion[StringOrInt] + + // then + s.name.map(_.show) shouldBe Some("sttp.tapir.SchemaMacroScala3Test.StringOrInt") + + s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } + val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[StringOrInt]] + coproduct.subtypes should have size 2 + coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString()) + coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger()) + } + + it should "derive schema for a union type with generics (same type constructor, different arguments)" in { + // when + val s: Schema[List[String] | List[Int]] = Schema.derivedUnion[List[String] | List[Int]] + + // then + s.name.map(_.show) shouldBe Some("scala.collection.immutable.List[String]_or_scala.collection.immutable.List[Int]") + + s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } + val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[List[String] | List[Int]]] + coproduct.subtypes should have size 2 + // no subtype schemas for generic types, as there's no runtime tag + coproduct.subtypeSchema(List("")).map(_.schema.schemaType) shouldBe None + } + + it should "derive schema for a union type with generics (different type constructors)" in { + // when + val s: Schema[List[String] | Vector[Int]] = Schema.derivedUnion[List[String] | Vector[Int]] + + // then + s.name.map(_.show) shouldBe Some("scala.collection.immutable.List[String]_or_scala.collection.immutable.Vector[Int]") + + s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } + val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[List[String] | Vector[Int]]] + coproduct.subtypes should have size 2 + coproduct.subtypeSchema(List("")).map(_.schema.schemaType) should matchPattern { case Some(_) => } + coproduct.subtypeSchema(Vector(10)).map(_.schema.schemaType) should matchPattern { case Some(_) => } + } + + it should "derive schema for union types with 3 components" in { + // when + val s: Schema[String | Int | Boolean] = Schema.derivedUnion + + // then + s.name.map(_.show) shouldBe Some("java.lang.String_or_scala.Int_or_scala.Boolean") + + s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => } + val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[String | Int | Boolean]] + coproduct.subtypes should have size 3 + coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString()) + coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger()) + coproduct.subtypeSchema(true).map(_.schema.schemaType) shouldBe Some(SchemaType.SBoolean()) + } + object SchemaMacroScala3Test: enum Fruit: case Apple, Banana + + type StringOrInt = String | Int diff --git a/doc/endpoint/schemas.md b/doc/endpoint/schemas.md index e37b014ea5..133cbb8178 100644 --- a/doc/endpoint/schemas.md +++ b/doc/endpoint/schemas.md @@ -81,7 +81,9 @@ integration layer. This method may be used both with automatic and semi-automatic derivation. -## Derivation for recursive types in Scala3 +## Scala3-specific derivation + +### Derivation for recursive types In Scala3, any schemas for recursive types need to be provided as typed `implicit def` (not a `given`)! For example: @@ -96,6 +98,25 @@ object RecursiveTest { The implicit doesn't have to be defined in the companion object, just anywhere in scope. This applies to cases where the schema is looked up implicitly, e.g. for `jsonBody`. +### Derivation for union types + +Schemas for union types must be declared by hand, using the `Schema.derivedUnion[T]` method. Schemas for all components +of the union type must be available in the implicit scope at the point of invocation. For example: + +```scala +val s: Schema[String | Int] = Schema.derivedUnion +``` + +If the union type is a named alias, the type needs to be provided explicitly, e.g.: + +```scala +type StringOrInt = String | Int +val s: Schema[StringOrInt] = Schema.derivedUnion[StringOrInt] +``` + +If any of the components of the union type is a generic type, any of its validations will be skipped when validating +the union type, as it's not possible to generate a runtime check for the generic type. + ## Configuring derivation It is possible to configure Magnolia's automatic derivation to use `snake_case`, `kebab-case` or a custom field naming