diff --git a/modules/core/src/main/scala/schema.scala b/modules/core/src/main/scala/schema.scala index 4d5d751a..025f511f 100644 --- a/modules/core/src/main/scala/schema.scala +++ b/modules/core/src/main/scala/schema.scala @@ -6,11 +6,12 @@ package edu.gemini.grackle import atto.Atto._ import cats.data.Ior import cats.implicits._ +import edu.gemini.grackle.Ast.{ObjectTypeDefinition, TypeDefinition} +import edu.gemini.grackle.QueryInterpreter.{mkErrorResult, mkOneError} +import edu.gemini.grackle.ScalarType._ +import edu.gemini.grackle.Value._ import io.circe.Json -import QueryInterpreter.{ mkErrorResult, mkOneError } -import ScalarType._, Value._ - /** * Representation of a GraphQL schema * @@ -19,13 +20,14 @@ import ScalarType._, Value._ trait Schema { /** The types defined by this `Schema`. */ def types: List[NamedType] + /** The directives defined by this `Schema`. */ def directives: List[Directive] /** A reference by name to a type defined by this `Schema`. * - * `TypeRef`s refer to types defined in this schema by name and hence - * can be used as part of mutually recursive type definitions. + * `TypeRef`s refer to types defined in this schema by name and hence + * can be used as part of mutually recursive type definitions. */ def ref(tpnme: String): TypeRef = new TypeRef(this, tpnme) @@ -43,10 +45,10 @@ trait Schema { * * ``` * type Schema { - * query: Query! - * mutation: Mutation - * subscription: Subscription - * } + * query: Query! + * mutation: Mutation + * subscription: Subscription + * } * ``` * * is used. @@ -89,11 +91,13 @@ trait Schema { */ def schemaType: NamedType = definition("Schema").getOrElse(defaultSchemaType) - /** The type of queries defined by this `Schema` */ + /** The type of queries defined by this `Schema`*/ def queryType: NamedType = schemaType.field("query").asNamed.get - /** The type of mutations defined by this `Schema` */ + + /** The type of mutations defined by this `Schema`*/ def mutationType: Option[NamedType] = schemaType.field("mutation").asNamed - /** The type of subscriptions defined by this `Schema` */ + + /** The type of subscriptions defined by this `Schema`*/ def subscriptionType: Option[NamedType] = schemaType.field("subscription").asNamed override def toString = SchemaRenderer.renderSchema(this) @@ -120,8 +124,8 @@ sealed trait Type { /** `true` if this type is a subtype of `other`. */ def <:<(other: Type): Boolean = (this.dealias, other.dealias) match { - case (tp1, tp2) if tp1 == tp2 => true - case (tp1, UnionType(_, _, members)) => members.exists(tp1 <:< _.dealias) + case (tp1, tp2) if tp1 == tp2 => true + case (tp1, UnionType(_, _, members)) => members.exists(tp1 <:< _.dealias) case (ObjectType(_, _, _, interfaces), tp2) => interfaces.exists(_.dealias == tp2) case _ => false } @@ -357,10 +361,15 @@ sealed trait Type { sealed trait NamedType extends Type { /** The name of this type */ def name: String + override def dealias: NamedType = this + override def isNamed: Boolean = true + override def asNamed: Option[NamedType] = Some(this) + def description: Option[String] + override def toString: String = name } @@ -382,23 +391,27 @@ object JoinType { Some((fieldName, tpe)) } } + /** * A by name reference to a type defined in `schema`. */ case class TypeRef(schema: Schema, name: String) extends NamedType { override def dealias: NamedType = schema.definition(name).getOrElse(this) + override def exists: Boolean = schema.definition(name).isDefined + def description: Option[String] = dealias.description } /** * Represents scalar types such as Int, String, and Boolean. Scalars cannot have fields. + * * @see https://facebook.github.io/graphql/draft/#sec-Scalar */ case class ScalarType( - name: String, - description: Option[String] -) extends Type with NamedType + name: String, + description: Option[String] + ) extends Type with NamedType object ScalarType { def apply(tpnme: String): Option[ScalarType] = tpnme match { @@ -477,87 +490,95 @@ sealed trait TypeWithFields extends NamedType { /** * Interfaces are an abstract type where there are common fields declared. Any type that * implements an interface must define all the fields with names and types exactly matching. + * * @see https://facebook.github.io/graphql/draft/#sec-Interface */ case class InterfaceType( - name: String, - description: Option[String], - fields: List[Field], - interfaces: List[NamedType] -) extends Type with TypeWithFields { + name: String, + description: Option[String], + fields: List[Field], + interfaces: List[NamedType] + ) extends Type with TypeWithFields { override def isInterface: Boolean = true } /** * Object types represent concrete instantiations of sets of fields. + * * @see https://facebook.github.io/graphql/draft/#sec-Object */ case class ObjectType( - name: String, - description: Option[String], - fields: List[Field], - interfaces: List[NamedType] -) extends Type with TypeWithFields + name: String, + description: Option[String], + fields: List[Field], + interfaces: List[NamedType] + ) extends Type with TypeWithFields /** * Unions are an abstract type where no common fields are declared. The possible types of a union * are explicitly listed out in elements. Types can be made parts of unions without * modification of that type. + * * @see https://facebook.github.io/graphql/draft/#sec-Union */ case class UnionType( - name: String, - description: Option[String], - members: List[NamedType] -) extends Type with NamedType { + name: String, + description: Option[String], + members: List[NamedType] + ) extends Type with NamedType { override def toString: String = members.mkString("|") } /** * Enums are special scalars that can only have a defined set of values. + * * @see https://facebook.github.io/graphql/draft/#sec-Enum */ case class EnumType( - name: String, - description: Option[String], - enumValues: List[EnumValue] -) extends Type with NamedType { + name: String, + description: Option[String], + enumValues: List[EnumValue] + ) extends Type with NamedType { def hasValue(name: String): Boolean = enumValues.exists(_.name == name) + def value(name: String): Option[EnumValue] = enumValues.find(_.name == name) } /** * The `EnumValue` type represents one of possible values of an enum. + * * @see https://facebook.github.io/graphql/draft/#sec-The-__EnumValue-Type */ case class EnumValue( - name: String, - description: Option[String], - isDeprecated: Boolean = false, - deprecationReason: Option[String] = None -) + name: String, + description: Option[String], + isDeprecated: Boolean = false, + deprecationReason: Option[String] = None + ) /** * Input objects are composite types used as inputs into queries defined as a list of named input * values. + * * @see https://facebook.github.io/graphql/draft/#sec-Input-Object */ case class InputObjectType( - name: String, - description: Option[String], - inputFields: List[InputValue] -) extends Type with NamedType { + name: String, + description: Option[String], + inputFields: List[InputValue] + ) extends Type with NamedType { def inputFieldInfo(name: String): Option[InputValue] = inputFields.find(_.name == name) } /** * Lists represent sequences of values in GraphQL. A List type is a type modifier: it wraps * another type instance in the ofType field, which defines the type of each item in the list. + * * @see https://facebook.github.io/graphql/draft/#sec-Type-Kinds.List */ case class ListType( - ofType: Type -) extends Type { + ofType: Type + ) extends Type { override def toString: String = s"[$ofType]" } @@ -565,51 +586,66 @@ case class ListType( * A Non‐null type is a type modifier: it wraps another type instance in the `ofType` field. * Non‐null types do not allow null as a response, and indicate required inputs for arguments * and input object fields. + * * @see https://facebook.github.io/graphql/draft/#sec-Type-Kinds.Non-Null */ case class NullableType( - ofType: Type -) extends Type { + ofType: Type + ) extends Type { override def toString: String = s"$ofType?" } /** * The `Field` type represents each field in an Object or Interface type. + * * @see https://facebook.github.io/graphql/draft/#sec-The-__Field-Type */ -case class Field private ( - name: String, - description: Option[String], - args: List[InputValue], - tpe: Type, - isDeprecated: Boolean, - deprecationReason: Option[String] -) +case class Field private( + name: String, + description: Option[String], + args: List[InputValue], + tpe: Type, + isDeprecated: Boolean, + deprecationReason: Option[String] + ) /** - * @param defaultValue a String encoding (using the GraphQL language) of the default value used by - * this input value in the condition a value is not provided at runtime. + * @param defaultValue a String encoding (using the GraphQL language) of the default value used by + * this input value in the condition a value is not provided at runtime. */ -case class InputValue private ( - name: String, - description: Option[String], - tpe: Type, - defaultValue: Option[Value] -) +case class InputValue private( + name: String, + description: Option[String], + tpe: Type, + defaultValue: Option[Value] + ) sealed trait Value + object Value { + case class IntValue(value: Int) extends Value + case class FloatValue(value: Double) extends Value + case class StringValue(value: String) extends Value + case class BooleanValue(value: Boolean) extends Value + case class IDValue(value: String) extends Value + case class UntypedEnumValue(name: String) extends Value + case class TypedEnumValue(value: EnumValue) extends Value + case class UntypedVariableValue(name: String) extends Value + case class ListValue(elems: List[Value]) extends Value + case class ObjectValue(fields: List[(String, Value)]) extends Value + case object NullValue extends Value + case object AbsentValue extends Value def checkValue(iv: InputValue, value: Option[Value]): Result[Value] = @@ -693,32 +729,35 @@ object Value { /** * The `Directive` type represents a Directive that a server supports. + * * @see https://facebook.github.io/graphql/draft/#sec-The-__Directive-Type */ case class Directive( - name: String, - description: Option[String], - locations: List[Ast.DirectiveLocation], - args: List[InputValue] -) + name: String, + description: Option[String], + locations: List[Ast.DirectiveLocation], + args: List[InputValue] + ) /** * GraphQL schema parser */ object SchemaParser { - import Ast.{ Type => _, Value => _, Directive => DefinedDirective, _ }, OperationType._ + + import Ast.{Directive => DefinedDirective, Type => _, Value => _, _} + import OperationType._ /** - * Parse a query String to a query algebra term. + * Parse a query String to a query algebra term. * - * Yields a Query value on the right and accumulates errors on the left. + * Yields a Query value on the right and accumulates errors on the left. */ def parseText(text: String): Result[Schema] = { def toResult[T](pr: Either[String, T]): Result[T] = Ior.fromEither(pr).leftMap(mkOneError(_)) for { - doc <- toResult(GraphQLParser.Document.parseOnly(text).either) + doc <- toResult(GraphQLParser.Document.parseOnly(text).either) query <- parseDocument(doc) } yield query } @@ -742,10 +781,10 @@ object SchemaParser { description = None, fields = mkRootDef("query")(query) :: - List( - mutation.map(mkRootDef("mutation")), - subscription.map(mkRootDef("subscription")) - ).flatten, + List( + mutation.map(mkRootDef("mutation")), + subscription.map(mkRootDef("subscription")) + ).flatten, interfaces = Nil ) } @@ -766,8 +805,10 @@ object SchemaParser { } def mkTypeDefs(schema: Schema): Result[List[NamedType]] = { - val defns = doc.collect { case tpe: TypeDefinition => tpe } - defns.traverse(mkTypeDef(schema)) + val defns: List[TypeDefinition] = doc.collect { case tpe: TypeDefinition => tpe } + val namedTypeResults = defns.traverse(mkTypeDef(schema)) + + SchemaValidator.validateSchema(namedTypeResults, defns) } def mkTypeDef(schema: Schema)(td: TypeDefinition): Result[NamedType] = td match { @@ -780,12 +821,12 @@ object SchemaParser { case ObjectTypeDefinition(Name(nme), desc, fields0, ifs0, _) => for { fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } } yield ObjectType(nme, desc, fields, ifs) case InterfaceTypeDefinition(Name(nme), desc, fields0, ifs0, _) => for { fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } } yield InterfaceType(nme, desc, fields, ifs) case UnionTypeDefinition(Name(nme), desc, _, members0) => val members = members0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } @@ -804,7 +845,7 @@ object SchemaParser { val FieldDefinition(Name(nme), desc, args0, tpe0, dirs) = f for { args <- args0.traverse(mkInputValue(schema)) - tpe <- mkType(schema)(tpe0) + tpe <- mkType(schema)(tpe0) deprecation <- parseDeprecated(dirs) (isDeprecated, reason) = deprecation } yield Field(nme, desc, args, tpe, isDeprecated, reason) @@ -813,6 +854,7 @@ object SchemaParser { def mkType(schema: Schema)(tpe: Ast.Type): Result[Type] = { def loop(tpe: Ast.Type, nullable: Boolean): Result[Type] = { def wrap(tpe: Type): Type = if (nullable) NullableType(tpe) else tpe + tpe match { case Ast.Type.List(tpe) => loop(tpe, true).map(tpe => wrap(ListType(tpe))) case Ast.Type.NonNull(Left(tpe)) => loop(tpe, false) @@ -820,13 +862,14 @@ object SchemaParser { case Ast.Type.Named(Name(nme)) => wrap(ScalarType(nme).getOrElse(schema.ref(nme))).rightIor } } + loop(tpe, true) } def mkInputValue(schema: Schema)(f: InputValueDefinition): Result[InputValue] = { val InputValueDefinition(Name(nme), desc, tpe0, default0, _) = f for { - tpe <- mkType(schema)(tpe0) + tpe <- mkType(schema)(tpe0) dflt <- default0.traverse(parseValue) } yield InputValue(nme, desc, tpe, dflt) } @@ -869,7 +912,9 @@ object SchemaParser { object schema extends Schema { var types: List[NamedType] = Nil var schemaType1: NamedType = null + override def schemaType: NamedType = schemaType1 + var directives: List[Directive] = Nil def complete(types0: List[NamedType], schemaType0: NamedType, directives0: List[Directive]): this.type = { @@ -881,12 +926,60 @@ object SchemaParser { } for { - types <- mkTypeDefs(schema) + types <- mkTypeDefs(schema) schemaType <- mkSchemaType(schema) } yield schema.complete(types, schemaType, Nil) } } +object SchemaValidator { + def validateSchema(namedTypes: Result[List[NamedType]], defns: List[TypeDefinition]): Result[List[NamedType]] = + checkForUndefined(checkForDuplicates(namedTypes), defns) + + type NamedTypeWithIndex = (NamedType, Int) + + implicit object NamedTypeOrdering extends Ordering[NamedTypeWithIndex] { + def compare(a: NamedTypeWithIndex, b: NamedTypeWithIndex): Int = a._2 compare b._2 + } + + def dedupedOrError(dupes: Map[String, List[(NamedType, Int)]]): Result[List[NamedTypeWithIndex]] = dupes.map { + case (name, tpe) if tpe.length > 1 => mkErrorResult[NamedTypeWithIndex](s"Duplicate NamedType found: $name") + case (name, typeAndIndex) => typeAndIndex.headOption.map(_.rightIor).getOrElse(mkErrorResult(s"No NamedType found for $name, something has gone wrong.")) + }.toList.sequence + + def checkForDuplicates(namedTypes: Result[List[NamedType]]): Result[List[NamedType]] = + for { + types <- namedTypes + map = types.zipWithIndex.groupBy(_._1.name) + unordered <- dedupedOrError(map) + res = unordered.sorted.map(_._1) + } yield res + + def checkReferencedTypesAgainstDefinedTypes(namedTypes: Result[List[NamedType]], defns: List[TypeDefinition]): Result[List[NamedType]] = { + val defaultTypes = List(StringType, IntType, FloatType, BooleanType, IDType) + + val lefts = namedTypes.flatMap { t => + val typeNames = (t ::: defaultTypes).map(_.name) + + referencedTypes(defns).collect { + case tpe if !typeNames.contains(tpe) => mkErrorResult[NamedType](s"Reference to undefined type: $tpe") + }.sequence + } + namedTypes combine lefts + } + + def referencedTypes(defns: List[TypeDefinition]): List[String] = defns.collect { + case o: ObjectTypeDefinition => + (o.fields.flatMap(_.args.map(_.tpe)) ::: o.fields.map(_.tpe)) + .map(_.name.replaceAll("[\\W]", "")) + }.flatten + + def checkForUndefined(namedTypes: Result[List[NamedType]], defns: List[TypeDefinition]): Result[List[NamedType]] = + for { + res <- checkReferencedTypesAgainstDefinedTypes(namedTypes, defns) + } yield res +} + object SchemaRenderer { def renderSchema(schema: Schema): String = { def mkRootDef(fieldName: String)(tpe: NamedType): String = @@ -894,17 +987,17 @@ object SchemaRenderer { val fields = mkRootDef("query")(schema.queryType) :: - List( - schema.mutationType.map(mkRootDef("mutation")), - schema.subscriptionType.map(mkRootDef("subscription")) - ).flatten + List( + schema.mutationType.map(mkRootDef("mutation")), + schema.subscriptionType.map(mkRootDef("subscription")) + ).flatten val schemaDefn = if (fields.size == 1 && schema.queryType =:= schema.ref("Query")) "" else fields.mkString("schema {\n ", "\n ", "\n}\n") schemaDefn ++ - schema.types.map(renderTypeDefn).mkString("\n") + schema.types.map(renderTypeDefn).mkString("\n") } def renderTypeDefn(tpe: NamedType): String = { @@ -924,14 +1017,14 @@ object SchemaRenderer { s"""scalar $nme""" case ObjectType(nme, _, fields, ifs0) => - val ifs = if(ifs0.isEmpty) "" else " implements "+ifs0.map(_.name).mkString("&") + val ifs = if (ifs0.isEmpty) "" else " implements " + ifs0.map(_.name).mkString("&") s"""|type $nme$ifs { | ${fields.map(renderField).mkString("\n ")} |}""".stripMargin case InterfaceType(nme, _, fields, ifs0) => - val ifs = if(ifs0.isEmpty) "" else " implements "+ifs0.map(_.name).mkString("&") + val ifs = if (ifs0.isEmpty) "" else " implements " + ifs0.map(_.name).mkString("&") s"""|interface $nme$ifs { | ${fields.map(renderField).mkString("\n ")} @@ -958,9 +1051,9 @@ object SchemaRenderer { tpe match { case NullableType(tpe) => loop(tpe, true) - case ListType(tpe) => wrap(s"[${loop(tpe, false)}]") - case nt: NamedType => wrap(nt.name) - case NoType => "NoType" + case ListType(tpe) => wrap(s"[${loop(tpe, false)}]") + case nt: NamedType => wrap(nt.name) + case NoType => "NoType" } } @@ -987,9 +1080,9 @@ object SchemaRenderer { case TypedEnumValue(e) => e.name case ListValue(elems) => elems.map(renderValue).mkString("[", ", ", "]") case ObjectValue(fields) => - fields.map { - case (name, value) => s"$name : ${renderValue(value)}" - }.mkString("{", ", ", "}") + fields.map { + case (name, value) => s"$name : ${renderValue(value)}" + }.mkString("{", ", ", "}") case _ => "null" } diff --git a/modules/core/src/test/scala/schema/SchemaSpec.scala b/modules/core/src/test/scala/schema/SchemaSpec.scala new file mode 100644 index 00000000..d23dd5b4 --- /dev/null +++ b/modules/core/src/test/scala/schema/SchemaSpec.scala @@ -0,0 +1,80 @@ +// Copyright (c) 2016-2020 Association of Universities for Research in Astronomy, Inc. (AURA) +// For license information see LICENSE or https://opensource.org/licenses/BSD-3-Clause + +package schema + +import cats.data.Ior +import cats.data.Ior.Both +import cats.tests.CatsSuite +import edu.gemini.grackle.Schema + +final class SchemaSpec extends CatsSuite { + test("schema validation: undefined types: typo in the use of a Query result type") { + val schema = + Schema( + """ + type Query { + episodeById(id: String!): Episod + } + + type Episode { + id: String! + } + """ + ) + + schema match { + case Ior.Left(e) => assert(e.head.\\("message").head.asString.get == "Reference to undefined type: Episod") + case Both(a, b) => + assert(a.head.\\("message").head.asString.get == "Reference to undefined type: Episod") + assert(b.types.map(_.name) == List("Query", "Episode")) + case Ior.Right(b) => fail(s"Shouldn't compile: $b") + } + } + + test("schema validation: undefined types: typo in the use of an InputValueDefinition") { + val schema = Schema( + """ + type Query { + episodeById(id: CCid!): Episode + } + + scalar CCId + + type Episode { + id: CCId! + } + """ + ) + + schema match { + case Both(a, b) => + assert(a.head.\\("message").head.asString.get == "Reference to undefined type: CCid") + assert(b.types.map(_.name) == List("Query", "CCId", "Episode")) + case unexpected => fail(s"This was unexpected: $unexpected") + } + } + + test("schema validation: multiply-defined types") { + val schema = Schema( + """ + type Query { + episodeById(id: String!): Episode + } + + type Episode { + id: String! + } + + type Episode { + episodeId: String! + } + """ + ) + + schema match { + case Ior.Left(e) => assert(e.head.\\("message").head.asString.get == "Duplicate NamedType found: Episode") + case unexpected => fail(s"This was unexpected: $unexpected") + } + } +}