Skip to content

Commit

Permalink
Implement "Field Selection Merging" validation (ghostdogpr#1084)
Browse files Browse the repository at this point in the history
* feat: Validate field merging (5.3.2).

* remove import

* fix scala3

* Remove unused

* Print types in error

* wip: Use algorithm from Apollo

* uncached xing

* cleanup

* test: Add fragment tests from spec

* fix leaf case

* cleanup test names

* compare field names

* style: remove braces

* fix: memoize

* cleanup

* style: move stuff to utils

* fix: scala2

* only run once at top level

* zio-cached

* Cache plain scala function

This reverts commit 9a45e28.

* Add return type

* Pattern match on cache hit

* Use Chunks for better perf

* Use mutable map

* Pattern match

* use isObjectType

* Add helper

* Add benchmark

* Add new sangria

* rely on mutability

* fix empty match

* put -> update
  • Loading branch information
frekw authored Oct 15, 2021
1 parent 3b4ba1c commit 481983b
Show file tree
Hide file tree
Showing 9 changed files with 988 additions and 88 deletions.
177 changes: 171 additions & 6 deletions benchmarks/src/main/scala/caliban/GraphQLBenchmarks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,104 @@ class GraphQLBenchmarks {
}
"""

val fragmentsQuery = s"""
query IntrospectionQuery {
__schema {
queryType {
name
${"...on __Type { name }" * 100}
}
mutationType { name }
subscriptionType { name }
types {
...FullType
}
directives {
name
description
locations
args {
...InputValue
}
}
}
}

fragment FullType on __Type {
kind
name
description
fields(includeDeprecated: true) {
name
description
args {
...InputValue
}
type {
...TypeRef
}
isDeprecated
deprecationReason
}
inputFields {
...InputValue
}
interfaces {
...TypeRef
}
enumValues(includeDeprecated: true) {
name
description
isDeprecated
deprecationReason
}
possibleTypes {
...TypeRef
}
}

fragment InputValue on __InputValue {
name
description
type { ...TypeRef }
defaultValue
}

fragment TypeRef on __Type {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
${"...on __Type { kind name }" * 1000}
}
}
}
}
}
}
}
}
"""

val runtime: Runtime[ZEnv] = new BootstrapRuntime {
override val platform: Platform = Platform.benchmark
}
Expand Down Expand Up @@ -160,12 +258,19 @@ class GraphQLBenchmarks {
()
}

implicit val OriginEnum: EnumType[Origin] = deriveEnumType[Origin]()
implicit val CaptainType: ObjectType[Unit, Captain] = deriveObjectType[Unit, Captain]()
implicit val PilotType: ObjectType[Unit, Pilot] = deriveObjectType[Unit, Pilot]()
implicit val EngineerType: ObjectType[Unit, Engineer] = deriveObjectType[Unit, Engineer]()
implicit val MechanicType: ObjectType[Unit, Mechanic] = deriveObjectType[Unit, Mechanic]()
implicit val RoleType: UnionType[Unit] = UnionType(
@Benchmark
def fragmentsCaliban(): Unit = {
val io = interpreter.execute(fragmentsQuery)
runtime.unsafeRun(io)
()
}

implicit val OriginEnum: EnumType[Origin] = deriveEnumType[Origin]()
implicit val CaptainType: ObjectType[Unit, Captain] = deriveObjectType[Unit, Captain]()
implicit val PilotType: ObjectType[Unit, Pilot] = deriveObjectType[Unit, Pilot]()
implicit val EngineerType: ObjectType[Unit, Engineer] = deriveObjectType[Unit, Engineer]()
implicit val MechanicType: ObjectType[Unit, Mechanic] = deriveObjectType[Unit, Mechanic]()
implicit val RoleType: UnionType[Unit] = UnionType(
"Role",
types = List(PilotType, EngineerType, MechanicType, CaptainType)
)
Expand Down Expand Up @@ -236,4 +341,64 @@ class GraphQLBenchmarks {
()
}

object SangriaNewValidator {
import sangria.validation.RuleBasedQueryValidator
import sangria.validation.ValidationRule
import sangria.validation.rules._

val allRules =
new RuleBasedQueryValidator(
List(
new ValuesOfCorrectType,
new ExecutableDefinitions,
new FieldsOnCorrectType,
new FragmentsOnCompositeTypes,
new KnownArgumentNames,
new KnownDirectives,
new KnownFragmentNames,
new KnownTypeNames,
new LoneAnonymousOperation,
new NoFragmentCycles,
new NoUndefinedVariables,
new NoUnusedFragments,
new NoUnusedVariables,
//new OverlappingFieldsCanBeMerged,
new experimental.OverlappingFieldsCanBeMerged,
new PossibleFragmentSpreads,
new ProvidedRequiredArguments,
new ScalarLeafs,
new UniqueArgumentNames,
new UniqueDirectivesPerLocation,
new UniqueFragmentNames,
new UniqueInputFieldNames,
new UniqueOperationNames,
new UniqueVariableNames,
new VariablesAreInputTypes,
new VariablesInAllowedPosition,
new InputDocumentNonConflictingVariableInference,
new SingleFieldSubscriptions
)
)
}

@Benchmark
def fragmentsSangriaOld(): Unit = {
val future: Future[Json] =
Future
.fromTry(QueryParser.parse(fragmentsQuery))
.flatMap(queryAst => Executor.execute(schema, queryAst))
Await.result(future, 1 minute)
()
}

@Benchmark
def fragmentsSangriaNew(): Unit = {
val future: Future[Json] =
Future
.fromTry(QueryParser.parse(fragmentsQuery))
.flatMap(queryAst => Executor.execute(schema, queryAst, queryValidator = SangriaNewValidator.allRules))
Await.result(future, 1 minute)
()
}

}
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ lazy val benchmarks = project
.settings(
crossScalaVersions -= scala3,
libraryDependencies ++= Seq(
"org.sangria-graphql" %% "sangria" % "2.0.0",
"org.sangria-graphql" %% "sangria" % "2.1.3",
"org.sangria-graphql" %% "sangria-circe" % "1.3.0"
)
)
Expand Down
58 changes: 58 additions & 0 deletions core/src/main/scala/caliban/validation/FieldMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package caliban.validation

import caliban.introspection.adt._
import caliban.parsing.adt.Selection.{ Field, FragmentSpread, InlineFragment }
import caliban.parsing.adt._
import Utils._

object FieldMap {
val empty: FieldMap = Map.empty

implicit class FieldMapOps(val self: FieldMap) extends AnyVal {
def |+|(that: FieldMap): FieldMap =
(self.keySet ++ that.keySet).map { k =>
k -> (self.get(k).getOrElse(Set.empty) ++ that.get(k).getOrElse(Set.empty))
}.toMap

def show =
self.map { case (k, fields) =>
s"$k -> ${fields.map(_.fieldDef.name).mkString(", ")}"
}.mkString("\n")

def addField(
f: Field,
parentType: __Type,
selection: Field
): FieldMap = {
val responseName = f.alias.getOrElse(f.name)

getFields(parentType)
.flatMap(fields => fields.find(_.name == f.name))
.map { f =>
val sf = SelectedField(parentType, selection, f)
val entry = self.get(responseName).map(_ + sf).getOrElse(Set(sf))
self + (responseName -> entry)
}
.getOrElse(self)
}
}

def apply(context: Context, parentType: __Type, selectionSet: Iterable[Selection]): FieldMap =
selectionSet.foldLeft(FieldMap.empty)({ case (fields, selection) =>
selection match {
case FragmentSpread(name, directives) =>
context.fragments
.get(name)
.map { definition =>
val typ = getType(Some(definition.typeCondition), parentType, context)
apply(context, typ, definition.selectionSet) |+| fields
}
.getOrElse(fields)
case f: Field =>
fields.addField(f, parentType, f)
case InlineFragment(typeCondition, _, selectionSet) =>
val typ = getType(typeCondition, parentType, context)
apply(context, typ, selectionSet) |+| fields
}
})
}
132 changes: 132 additions & 0 deletions core/src/main/scala/caliban/validation/FragmentValidator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package caliban.validation

import caliban.CalibanError.ValidationError
import caliban.introspection.adt._
import caliban.parsing.adt.Selection
import zio.{ Chunk, IO, UIO }
import Utils._
import Utils.syntax._
import scala.collection.mutable

object FragmentValidator {
def findConflictsWithinSelectionSet(
context: Context,
parentType: __Type,
selectionSet: List[Selection]
): IO[ValidationError, Unit] = {
val shapeCache = scala.collection.mutable.Map.empty[Iterable[Selection], Chunk[String]]
val parentsCache = scala.collection.mutable.Map.empty[Iterable[Selection], Chunk[String]]
val groupsCache = scala.collection.mutable.Map.empty[Set[SelectedField], Chunk[Set[SelectedField]]]

def sameResponseShapeByName(context: Context, parentType: __Type, set: Iterable[Selection]): Chunk[String] =
shapeCache.get(set) match {
case Some(value) => value
case None =>
val fields = FieldMap(context, parentType, set)
val res = Chunk.fromIterable(fields.flatMap { case (name, values) =>
cross(values).flatMap { case (f1, f2) =>
if (doTypesConflict(f1.fieldDef.`type`(), f2.fieldDef.`type`())) {
Chunk(
s"$name has conflicting types: ${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name
.getOrElse("")}.${f2.fieldDef.name}. Try using an alias."
)
} else
sameResponseShapeByName(context, parentType, f1.selection.selectionSet ++ f2.selection.selectionSet)
}
})
shapeCache.update(set, res)
res
}

def sameForCommonParentsByName(context: Context, parentType: __Type, set: Iterable[Selection]): Chunk[String] =
parentsCache.get(set) match {
case Some(value) => value
case None =>
val fields = FieldMap(context, parentType, set)
val res = Chunk.fromIterable(fields.flatMap({ case (name, fields) =>
groupByCommonParents(context, parentType, fields).flatMap { group =>
val merged = group.flatMap(_.selection.selectionSet)
requireSameNameAndArguments(group) ++ sameForCommonParentsByName(context, parentType, merged)
}
}))
parentsCache.update(set, res)
res
}

def doTypesConflict(t1: __Type, t2: __Type): Boolean =
if (isNonNull(t1))
if (isNonNull(t2)) (t1.ofType, t2.ofType).mapN((p1, p2) => doTypesConflict(p1, p2)).getOrElse(true)
else true
else if (isNonNull(t2))
true
else if (isListType(t1))
if (isListType(t2)) (t1.ofType, t2.ofType).mapN((p1, p2) => doTypesConflict(p1, p2)).getOrElse(true)
else true
else if (isListType(t2))
true
else if (isLeafType(t1) && isLeafType(t2)) {
t1.name != t2.name
} else if (!isComposite(t1) || !isComposite(t2))
true
else
false

def requireSameNameAndArguments(fields: Set[SelectedField]) =
cross(fields).flatMap { case (f1, f2) =>
if (f1.fieldDef.name != f2.fieldDef.name) {
List(
s"${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name.getOrElse("")}.${f2.fieldDef.name} are different fields."
)
} else if (f1.selection.arguments != f2.selection.arguments)
List(s"${f1.fieldDef.name} and ${f2.fieldDef.name} have different arguments")
else List()
}

def groupByCommonParents(
context: Context,
parentType: __Type,
fields: Set[SelectedField]
): Chunk[Set[SelectedField]] =
groupsCache.get(fields) match {
case Some(value) => value
case None =>
val abstractGroup = fields.collect {
case field if !isConcrete(field.parentType) => field
}

val concreteGroups = mutable.Map.empty[String, Set[SelectedField]]

fields
.foreach({
case field @ SelectedField(
__Type(_, Some(name), _, _, _, _, _, _, _, _, _),
_,
_
) if isConcrete(field.parentType) =>
val value = concreteGroups.get(name).map(_ + field).getOrElse(Set(field))
concreteGroups.update(name, value)
case _ => ()
})

val res =
if (concreteGroups.size < 1) Chunk(fields)
else Chunk.fromIterable(concreteGroups.values.map(_ ++ abstractGroup))

groupsCache.update(fields, res)
res
}

val fields = FieldMap(
context,
parentType,
selectionSet
)

val conflicts = sameResponseShapeByName(context, parentType, selectionSet) ++
sameForCommonParentsByName(context, parentType, selectionSet)

IO.whenCase(conflicts) { case Chunk(head, _*) =>
IO.fail(ValidationError(head, ""))
}
}
}
Loading

0 comments on commit 481983b

Please sign in to comment.