Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add derivation for schemas of union types #3425

Merged
merged 2 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions core/src/main/scala-3/sttp/tapir/internal/SNameMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import sttp.tapir.SchemaType.*
object SNameMacros {

inline def typeFullName[T] = ${ typeFullNameImpl[T] }
private def typeFullNameImpl[T: Type](using q: Quotes) = {
private def typeFullNameImpl[T: Type](using q: Quotes): Expr[String] = {
import q.reflect.*
val tpe = TypeRepr.of[T]
Expr(typeFullNameFromTpe(tpe))
}
def typeFullNameFromTpe(using q: Quotes)(tpe: q.reflect.TypeRepr): String = {
import q.reflect.*

def normalizedName(s: Symbol): String = if s.flags.is(Flags.Module) then s.name.stripSuffix("$") else s.name
Expand All @@ -18,9 +23,7 @@ object SNameMacros {
else if sym == defn.RootClass then List.empty
else nameChain(sym.owner) :+ normalizedName(sym)

val tpe = TypeRepr.of[T]

Expr(nameChain(tpe.typeSymbol).mkString("."))
nameChain(tpe.typeSymbol).mkString(".")
}

def extractTypeArguments(using q: Quotes)(tpe: q.reflect.TypeRepr): List[String] = {
Expand Down
121 changes: 119 additions & 2 deletions core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package sttp.tapir.macros

import sttp.tapir.{Validator, Schema, SchemaAnnotations, SchemaType}
import sttp.tapir.{Schema, SchemaAnnotations, SchemaType, Validator}
import sttp.tapir.SchemaType.SchemaWithValue
import sttp.tapir.generic.Configuration
import magnolia1._
import magnolia1.*
import sttp.tapir.Schema.SName
import sttp.tapir.generic.auto.SchemaMagnoliaDerivation

import scala.quoted.*
Expand Down Expand Up @@ -131,6 +132,11 @@ trait SchemaCompanionMacros extends SchemaMagnoliaDerivation {
*/
inline def oneOfWrapped[E](implicit conf: Configuration): Schema[E] = ${ SchemaCompanionMacros.generateOneOfWrapped[E]('conf) }

/** Derives the schema for a union type `E`. Schemas for all components of the union type must be available in the implicit scope at the
* point of invocation.
*/
inline def derivedUnion[E]: Schema[E] = ${ SchemaCompanionMacros.derivedUnion[E] }

/** Create a schema for an [[Enumeration]], where the validator is created using the enumeration's values. The low-level representation of
* the enum is a `String`, and the enum values in the documentation will be encoded using `.toString`.
*/
Expand Down Expand Up @@ -337,4 +343,115 @@ private[tapir] object SchemaCompanionMacros {
report.errorAndAbort(s"Can only derive Schema for values owned by scala.Enumeration")
}
}

def derivedUnion[T: Type](using q: Quotes): Expr[Schema[T]] = {
import q.reflect.*

val tpe = TypeRepr.of[T]
def typeParams = SNameMacros.extractTypeArguments(tpe)

// first, finding all of the components of the union type
def findOrTypes(t: TypeRepr, failIfNotOrType: Boolean = true): List[TypeRepr] =
t.dealias match {
// only failing if the top-level type is not an OrType
case OrType(l, r) => findOrTypes(l, failIfNotOrType = false) ++ findOrTypes(r, failIfNotOrType = false)
case _ if failIfNotOrType =>
report.errorAndAbort(s"Can only derive Schemas for union types, got: ${tpe.show}")
case _ => List(t)
}

val orTypes = findOrTypes(tpe)

// then, looking up schemas for each of the components
val schemas: List[Expr[Schema[_]]] = orTypes.map { orType =>
orType.asType match {
case '[f] =>
Expr.summon[Schema[f]] match {
case Some(subSchema) => subSchema
case None =>
val typeName = TypeRepr.of[f].show
report.errorAndAbort(s"Cannot summon schema for `$typeName`. Make sure schema derivation is properly configured.")
}
}
}

// then, constructing the name of the schema; if the type is not named, we generate a name by hand by concatenating
// names of the components
val orTypesNames = Expr.ofList(orTypes.map { orType =>
orType.asType match {
case '[f] =>
val typeParams = SNameMacros.extractTypeArguments(orType)
'{ _root_.sttp.tapir.Schema.SName(SNameMacros.typeFullName[f], ${ Expr(typeParams) }) }
}
})

val baseName = SNameMacros.typeFullNameFromTpe(tpe)
val snameExpr = if baseName.isEmpty then '{ SName(${ orTypesNames }.map(_.show).mkString("_or_")) }
else '{ SName(${ Expr(baseName) }, ${ Expr(typeParams) }) }

// then, generating the method which maps a specific value to a schema, trying to match to one of the components
val typesAndSchemas = orTypes.zip(schemas) // both lists have the same length
def subtypeSchema(e: Expr[T]) = {
val eIdent = e.asTerm match {
case Inlined(_, _, ei: Ident) => ei
case ei: Ident => ei
}

// if an or-type component that is generic appears more than once, we won't be able to perform a runtime check,
// to get the correct schema; in such case, instead of generating a `case ...`, we add a (single!)
// `case _ => None` to the match
val genericTypesThatAppearMoreThanOnce = {
var seen = Set[String]()
var result = Set[String]()

orTypes.foreach { orType =>
orType.classSymbol match {
case Some(sym) if orType.typeArgs.nonEmpty => // is generic
if seen.contains(sym.fullName) then result = result + sym.fullName
else seen = seen + sym.fullName
case _ => // skip
}
}

result
}

val baseCases = typesAndSchemas.flatMap { (orType, orTypeSchema) =>
def caseThen = Block(Nil, '{ Some(SchemaWithValue($orTypeSchema.asInstanceOf[Schema[Any]], $e)) }.asTerm)

orType.classSymbol match
case None => Some(CaseDef(Ident(orType.termSymbol.termRef), None, caseThen))
case Some(sym) if orType.typeArgs.nonEmpty =>
if genericTypesThatAppearMoreThanOnce.contains(sym.fullName) then None
else
val wildcardTypeParameters: List[Tree] =
List.fill(orType.typeArgs.length)(TypeBoundsTree(TypeTree.of[Nothing], TypeTree.of[Any]))
Some(CaseDef(Typed(Wildcard(), Applied(TypeIdent(sym), wildcardTypeParameters)), None, caseThen))
case Some(sym) => Some(CaseDef(Typed(Wildcard(), TypeIdent(sym)), None, caseThen))
}
val cases =
if genericTypesThatAppearMoreThanOnce.nonEmpty
then baseCases :+ CaseDef(Wildcard(), None, Block(Nil, '{ None }.asTerm))
else baseCases
val t = Match(eIdent, cases)

t.asExprOf[Option[SchemaWithValue[_]]]
}

// finally, generating code which creates the SCoproduct
'{
import _root_.sttp.tapir.Schema
import _root_.sttp.tapir.Schema._
import _root_.sttp.tapir.SchemaType._
import _root_.scala.collection.immutable.{List, Map}

val childSchemas = List(${ Varargs(schemas) }: _*)
val sname = $snameExpr

Schema(
schemaType = SCoproduct[T](childSchemas, None) { e => ${ subtypeSchema('{ e }) } },
name = Some(sname)
)
}
}
}
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros {
}

case class SName(fullName: String, typeParameterShortNames: List[String] = Nil) {
def show: String = fullName + typeParameterShortNames.mkString("[", ",", "]")
def show: String = fullName + (if (typeParameterShortNames.isEmpty) "" else typeParameterShortNames.mkString("[", ",", "]"))
}
object SName {
val Unit: SName = SName(fullName = "Unit")
Expand Down
78 changes: 75 additions & 3 deletions core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package sttp.tapir

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.concurrent.duration.Duration
import sttp.tapir.internal.SNameMacros

class SchemaMacroScala3Test extends AnyFlatSpec with Matchers:
import SchemaMacroScala3Test._

it should "validate enum" in {
it should "derive a one-of-wrapped schema for enums" in {
given s: Schema[Fruit] = Schema.oneOfWrapped

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
Expand All @@ -18,6 +17,79 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers:
coproduct.subtypeSchema(Fruit.Apple).map(_.value) shouldBe Some(Fruit.Apple)
}

it should "derive schema for union types" in {
// when
val s: Schema[String | Int] = Schema.derivedUnion

// then
s.name.map(_.show) shouldBe Some("java.lang.String_or_scala.Int")

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[String | Int]]
coproduct.subtypes should have size 2
coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString())
coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger())
}

it should "derive schema for a named union type" in {
// when
val s: Schema[StringOrInt] = Schema.derivedUnion[StringOrInt]

// then
s.name.map(_.show) shouldBe Some("sttp.tapir.SchemaMacroScala3Test.StringOrInt")

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[StringOrInt]]
coproduct.subtypes should have size 2
coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString())
coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger())
}

it should "derive schema for a union type with generics (same type constructor, different arguments)" in {
// when
val s: Schema[List[String] | List[Int]] = Schema.derivedUnion[List[String] | List[Int]]

// then
s.name.map(_.show) shouldBe Some("scala.collection.immutable.List[String]_or_scala.collection.immutable.List[Int]")

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[List[String] | List[Int]]]
coproduct.subtypes should have size 2
// no subtype schemas for generic types, as there's no runtime tag
coproduct.subtypeSchema(List("")).map(_.schema.schemaType) shouldBe None
}

it should "derive schema for a union type with generics (different type constructors)" in {
// when
val s: Schema[List[String] | Vector[Int]] = Schema.derivedUnion[List[String] | Vector[Int]]

// then
s.name.map(_.show) shouldBe Some("scala.collection.immutable.List[String]_or_scala.collection.immutable.Vector[Int]")

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[List[String] | Vector[Int]]]
coproduct.subtypes should have size 2
coproduct.subtypeSchema(List("")).map(_.schema.schemaType) should matchPattern { case Some(_) => }
coproduct.subtypeSchema(Vector(10)).map(_.schema.schemaType) should matchPattern { case Some(_) => }
}

it should "derive schema for union types with 3 components" in {
// when
val s: Schema[String | Int | Boolean] = Schema.derivedUnion

// then
s.name.map(_.show) shouldBe Some("java.lang.String_or_scala.Int_or_scala.Boolean")

s.schemaType should matchPattern { case SchemaType.SCoproduct(_, _) => }
val coproduct = s.schemaType.asInstanceOf[SchemaType.SCoproduct[String | Int | Boolean]]
coproduct.subtypes should have size 3
coproduct.subtypeSchema("a").map(_.schema.schemaType) shouldBe Some(SchemaType.SString())
coproduct.subtypeSchema(10).map(_.schema.schemaType) shouldBe Some(SchemaType.SInteger())
coproduct.subtypeSchema(true).map(_.schema.schemaType) shouldBe Some(SchemaType.SBoolean())
}

object SchemaMacroScala3Test:
enum Fruit:
case Apple, Banana

type StringOrInt = String | Int
23 changes: 22 additions & 1 deletion doc/endpoint/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ integration layer.

This method may be used both with automatic and semi-automatic derivation.

## Derivation for recursive types in Scala3
## Scala3-specific derivation

### Derivation for recursive types

In Scala3, any schemas for recursive types need to be provided as typed `implicit def` (not a `given`)!
For example:
Expand All @@ -96,6 +98,25 @@ object RecursiveTest {
The implicit doesn't have to be defined in the companion object, just anywhere in scope. This applies to cases where
the schema is looked up implicitly, e.g. for `jsonBody`.

### Derivation for union types

Schemas for union types must be declared by hand, using the `Schema.derivedUnion[T]` method. Schemas for all components
of the union type must be available in the implicit scope at the point of invocation. For example:

```scala
val s: Schema[String | Int] = Schema.derivedUnion
```

If the union type is a named alias, the type needs to be provided explicitly, e.g.:

```scala
type StringOrInt = String | Int
val s: Schema[StringOrInt] = Schema.derivedUnion[StringOrInt]
```

If any of the components of the union type is a generic type, any of its validations will be skipped when validating
the union type, as it's not possible to generate a runtime check for the generic type.

## Configuring derivation

It is possible to configure Magnolia's automatic derivation to use `snake_case`, `kebab-case` or a custom field naming
Expand Down
Loading