diff --git a/build.sbt b/build.sbt index 0ce2a191..dc72bc5f 100644 --- a/build.sbt +++ b/build.sbt @@ -94,8 +94,27 @@ lazy val core = project description := "Scala GraphQL implementation", mimaPreviousArtifacts := Set("org.sangria-graphql" %% "sangria-core" % "4.0.0"), mimaBinaryIssueFilters ++= Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.execute"), + ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.prepare"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "sangria.validation.RuleBasedQueryValidator.this"), + "sangria.execution.QueryReducerExecutor.reduceQueryWithoutVariables"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.execution.ValueCoercionHelper.isValidValue"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.execution.ValueCoercionHelper.getVariableValue"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.execution.batch.BatchExecutor.executeBatch"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.schema.ResolverBasedAstSchemaBuilder.validateSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.QueryValidator.validateQuery"), + ProblemFilters.exclude[ReversedMissingMethodProblem]( + "sangria.validation.QueryValidator.validateQuery"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.RuleBasedQueryValidator.validateQuery"), ProblemFilters.exclude[DirectMissingMethodProblem]( "sangria.validation.ValidationContext.this") ), diff --git a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala index 07b0b558..f0169695 100644 --- a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala +++ b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala @@ -11,7 +11,7 @@ import sangria.validation.{QueryValidator, RuleBasedQueryValidator, Violation} @State(Scope.Thread) class OverlappingFieldsCanBeMergedBenchmark { - val validator: QueryValidator = RuleBasedQueryValidator( + val validator: QueryValidator = new RuleBasedQueryValidator( List(new rules.OverlappingFieldsCanBeMerged)) val schema: Schema[_, _] = @@ -98,7 +98,7 @@ class OverlappingFieldsCanBeMergedBenchmark { bh.consume(doValidate(validator, deepAbstractConcrete)) private def doValidate(validator: QueryValidator, document: Document): Vector[Violation] = { - val result = validator.validateQuery(schema, document) + val result = validator.validateQuery(schema, document, None) require(result.isEmpty) result } diff --git a/modules/core/src/main/scala/sangria/execution/Executor.scala b/modules/core/src/main/scala/sangria/execution/Executor.scala index e1bf053b..4452ced8 100644 --- a/modules/core/src/main/scala/sangria/execution/Executor.scala +++ b/modules/core/src/main/scala/sangria/execution/Executor.scala @@ -19,7 +19,8 @@ case class Executor[Ctx, Root]( deprecationTracker: DeprecationTracker = DeprecationTracker.empty, middleware: List[Middleware[Ctx]] = Nil, maxQueryDepth: Option[Int] = None, - queryReducers: List[QueryReducer[Ctx, _]] = Nil + queryReducers: List[QueryReducer[Ctx, _]] = Nil, + errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext) { def prepare[Input]( queryAst: ast.Document, @@ -29,7 +30,7 @@ case class Executor[Ctx, Root]( variables: Input = emptyMapVars )(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = { val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst)) + TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) if (violations.nonEmpty) Future.failed(ValidationError(violations, exceptionHandler)) @@ -49,7 +50,9 @@ case class Executor[Ctx, Root]( operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) unmarshalledVariables <- valueCollector.getVariableValues( operation.variables, - scalarMiddleware) + scalarMiddleware, + errorsLimit + ) fieldCollector = new FieldCollector[Ctx, Root]( schema, queryAst, @@ -141,7 +144,7 @@ case class Executor[Ctx, Root]( um: InputUnmarshaller[Input], scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = { val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst)) + TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) if (violations.nonEmpty) scheme.failed(ValidationError(violations, exceptionHandler)) @@ -161,7 +164,9 @@ case class Executor[Ctx, Root]( operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) unmarshalledVariables <- valueCollector.getVariableValues( operation.variables, - scalarMiddleware) + scalarMiddleware, + errorsLimit + ) fieldCollector = new FieldCollector[Ctx, Root]( schema, queryAst, @@ -324,7 +329,8 @@ object Executor { deprecationTracker: DeprecationTracker = DeprecationTracker.empty, middleware: List[Middleware[Ctx]] = Nil, maxQueryDepth: Option[Int] = None, - queryReducers: List[QueryReducer[Ctx, _]] = Nil + queryReducers: List[QueryReducer[Ctx, _]] = Nil, + errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext, marshaller: ResultMarshaller, @@ -338,7 +344,8 @@ object Executor { deprecationTracker, middleware, maxQueryDepth, - queryReducers) + queryReducers, + errorsLimit) .execute(queryAst, userContext, root, operationName, variables) def prepare[Ctx, Root, Input]( @@ -354,7 +361,8 @@ object Executor { deprecationTracker: DeprecationTracker = DeprecationTracker.empty, middleware: List[Middleware[Ctx]] = Nil, maxQueryDepth: Option[Int] = None, - queryReducers: List[QueryReducer[Ctx, _]] = Nil + queryReducers: List[QueryReducer[Ctx, _]] = Nil, + errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext, um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = @@ -366,7 +374,8 @@ object Executor { deprecationTracker, middleware, maxQueryDepth, - queryReducers) + queryReducers, + errorsLimit) .prepare(queryAst, userContext, root, operationName, variables) def getOperationRootType[Ctx, Root]( diff --git a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala index c086c064..d82ccefa 100644 --- a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala @@ -19,9 +19,10 @@ object QueryReducerExecutor { queryValidator: QueryValidator = QueryValidator.default, exceptionHandler: ExceptionHandler = ExceptionHandler.empty, deprecationTracker: DeprecationTracker = DeprecationTracker.empty, - middleware: List[Middleware[Ctx]] = Nil + middleware: List[Middleware[Ctx]] = Nil, + errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext): Future[(Ctx, TimeMeasurement)] = { - val violations = queryValidator.validateQuery(schema, queryAst) + val violations = queryValidator.validateQuery(schema, queryAst, errorsLimit) if (violations.nonEmpty) Future.failed(ValidationError(violations, exceptionHandler)) diff --git a/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala b/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala index 289ece6e..d5758f2b 100644 --- a/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala +++ b/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala @@ -580,107 +580,128 @@ class ValueCoercionHelper[Ctx]( nodeLocation.toList ++ firstValue.toList } - def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit - um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match { - case (OptionInputType(ofType), Some(value)) if um.isDefined(value) => - isValidValue(ofType, Some(value)) - case (OptionInputType(_), _) => Vector.empty - case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil)) - - case (ListInputType(ofType), Some(values)) if um.isListNode(values) => - um.getListValue(values) - .toVector - .flatMap(v => - isValidValue( - ofType, - v match { - case opt: Option[In @unchecked] => opt - case other => Option(other) - }).map(ListValueViolation(0, _, sourceMapper, Nil))) - - case (ListInputType(ofType), Some(value)) if um.isDefined(value) => - isValidValue( - ofType, - value match { - case opt: Option[In @unchecked] => opt - case other => Option(other) - }).map(ListValueViolation(0, _, sourceMapper, Nil)) - - case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) => - val unknownFields = um.getMapKeys(valueMap).toVector.collect { - case f if !objTpe.fieldsByName.contains(f) => - UnknownInputObjectFieldViolation( - SchemaRenderer.renderTypeName(objTpe, true), - f, - sourceMapper, - Nil) - } + private def isValidValue[In]( + inputType: InputType[_], + input: Option[In], + errorsLimit: Option[Int])(implicit um: InputUnmarshaller[In]): Vector[Violation] = { - val fieldViolations = - objTpe.fields.toVector.flatMap(f => - isValidValue(f.fieldType, um.getMapValue(valueMap, f.name)) - .map(MapValueViolation(f.name, _, sourceMapper, Nil))) + // keeping track of the number of errors + var errors = 0 + def addViolation(violation: Violation): Vector[Violation] = { + errors += 1 + Vector(violation) + } - fieldViolations ++ unknownFields + def isValidValueRec(tpe: InputType[_], in: Option[In])(implicit + um: InputUnmarshaller[In]): Vector[Violation] = + // early termination if errors limit is defined and the current number of violations exceeds the limit + if (errorsLimit.exists(_ <= errors)) Vector.empty + else + (tpe, in) match { + case (OptionInputType(ofType), Some(value)) if um.isDefined(value) => + isValidValueRec(ofType, Some(value)) + case (OptionInputType(_), _) => Vector.empty + case (_, None) => addViolation(NotNullValueIsNullViolation(sourceMapper, Nil)) + + case (ListInputType(ofType), Some(values)) if um.isListNode(values) => + um.getListValue(values) + .toVector + .flatMap(v => + isValidValueRec( + ofType, + v match { + case opt: Option[In @unchecked] => opt + case other => Option(other) + }).map(ListValueViolation(0, _, sourceMapper, Nil))) + + case (ListInputType(ofType), Some(value)) if um.isDefined(value) => + isValidValueRec( + ofType, + value match { + case opt: Option[In @unchecked] => opt + case other => Option(other) + }).map(ListValueViolation(0, _, sourceMapper, Nil)) + + case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) => + val unknownFields = um.getMapKeys(valueMap).toVector.collect { + case f if !objTpe.fieldsByName.contains(f) => + addViolation( + UnknownInputObjectFieldViolation( + SchemaRenderer.renderTypeName(objTpe, true), + f, + sourceMapper, + Nil)).head + } - case (objTpe: InputObjectType[_], _) => - Vector( - InputObjectIsOfWrongTypeMissingViolation( - SchemaRenderer.renderTypeName(objTpe, true), - sourceMapper, - Nil)) + val fieldViolations = + objTpe.fields.toVector.flatMap(f => + isValidValueRec(f.fieldType, um.getMapValue(valueMap, f.name)) + .map(MapValueViolation(f.name, _, sourceMapper, Nil))) - case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => scalar.coerceInput(node) - case other => scalar.coerceUserInput(other) - } + fieldViolations ++ unknownFields - coerced match { - case Left(violation) => Vector(violation) - case _ => Vector.empty - } + case (objTpe: InputObjectType[_], _) => + addViolation( + InputObjectIsOfWrongTypeMissingViolation( + SchemaRenderer.renderTypeName(objTpe, true), + sourceMapper, + Nil)) - case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => scalar.aliasFor.coerceInput(node) - case other => scalar.aliasFor.coerceUserInput(other) - } + case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => scalar.coerceInput(node) + case other => scalar.coerceUserInput(other) + } - coerced match { - case Left(violation) => Vector(violation) - case Right(v) => - scalar.fromScalar(v) match { - case Left(violation) => Vector(violation) - case _ => Vector.empty - } - } + coerced match { + case Left(violation) => addViolation(violation) + case _ => Vector.empty + } - case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => enumT.coerceInput(node) - case other => enumT.coerceUserInput(other) - } + case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => scalar.aliasFor.coerceInput(node) + case other => scalar.aliasFor.coerceUserInput(other) + } - coerced match { - case Left(violation) => Vector(violation) - case _ => Vector.empty - } + coerced match { + case Left(violation) => addViolation(violation) + case Right(v) => + scalar.fromScalar(v) match { + case Left(violation) => addViolation(violation) + case _ => Vector.empty + } + } - case (enumT: EnumType[_], Some(value)) => - Vector(EnumCoercionViolation) + case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => enumT.coerceInput(node) + case other => enumT.coerceUserInput(other) + } + + coerced match { + case Left(violation) => addViolation(violation) + case _ => Vector.empty + } + + case (enumT: EnumType[_], Some(value)) => + addViolation(EnumCoercionViolation) - case _ => - Vector(GenericInvalidValueViolation(sourceMapper, Nil)) + case _ => + addViolation(GenericInvalidValueViolation(sourceMapper, Nil)) + } + + isValidValueRec(inputType, input) } def getVariableValue[In]( definition: ast.VariableDefinition, tpe: InputType[_], input: Option[In], - fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit - um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = { - val violations = isValidValue(tpe, input) + fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]], + errorsLimit: Option[Int] + )(implicit um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = { + val violations = isValidValue(tpe, input, errorsLimit) if (violations.isEmpty) { val fieldPath = s"$$${definition.name}" :: Nil diff --git a/modules/core/src/main/scala/sangria/execution/ValueCollector.scala b/modules/core/src/main/scala/sangria/execution/ValueCollector.scala index 9aadc755..f1e7a428 100644 --- a/modules/core/src/main/scala/sangria/execution/ValueCollector.scala +++ b/modules/core/src/main/scala/sangria/execution/ValueCollector.scala @@ -28,8 +28,14 @@ class ValueCollector[Ctx, Input]( def getVariableValues( definitions: Vector[ast.VariableDefinition], - fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]]) - : Try[Map[String, VariableValue]] = + fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]] + ): Try[Map[String, VariableValue]] = getVariableValues(definitions, fromScalarMiddleware, None) + + def getVariableValues( + definitions: Vector[ast.VariableDefinition], + fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]], + errorsLimit: Option[Int] + ): Try[Map[String, VariableValue]] = if (!um.isMapNode(inputVars)) Failure( new ExecutionError( @@ -37,33 +43,51 @@ class ValueCollector[Ctx, Input]( exceptionHandler)) else { val res = - definitions.foldLeft(Vector.empty[(String, Either[Vector[Violation], VariableValue])]) { + definitions.foldLeft( + (0, Vector.empty[(String, Either[Vector[Violation], VariableValue])])) { case (acc, varDef) => - val value = schema - .getInputType(varDef.tpe) - .map( - coercionHelper.getVariableValue( + val (accErrors, accResult) = acc + + // early termination if errors limit is defined and the current number of violations exceeds the limit + if (errorsLimit.exists(_ <= accErrors)) acc + else { + val value = schema + .getInputType(varDef.tpe) + .map(coercionHelper.getVariableValue( varDef, _, um.getRootMapValue(inputVars, varDef.name), - fromScalarMiddleware)) - .getOrElse( - Left( - Vector( - UnknownVariableTypeViolation( - varDef.name, - QueryRenderer.render(varDef.tpe), - sourceMapper, - varDef.location.toList)))) - - value match { - case Right(Some(v)) => acc :+ (varDef.name -> Right(v)) - case Right(None) => acc - case Left(violations) => acc :+ (varDef.name -> Left(violations)) + fromScalarMiddleware, + // calculate the allowed number of errors to be returned (if any) + errorsLimit.map(_ - accErrors) + )) + .getOrElse( + Left( + Vector( + UnknownVariableTypeViolation( + varDef.name, + QueryRenderer.render(varDef.tpe), + sourceMapper, + varDef.location.toList)))) + + value match { + case Right(Some(v)) => (accErrors, accResult :+ (varDef.name -> Right(v))) + case Right(None) => acc + case Left(violations) => + // number of errors that is allowed to use (all if errors limit is not defined) + val errorsLeftToUse = errorsLimit.fold(violations.length) { limit => + Math.min(violations.length, limit - accErrors) + } + + ( + accErrors + errorsLeftToUse, + accResult :+ (varDef.name -> Left(violations.take(errorsLeftToUse))) + ) + } } } - val (errors, values) = res.partition(_._2.isLeft) + val (errors, values) = res._2.partition(_._2.isLeft) if (errors.nonEmpty) Failure( diff --git a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala index f59d49cd..17de05f9 100644 --- a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala @@ -75,7 +75,8 @@ object BatchExecutor { middleware: List[Middleware[Ctx]] = Nil, maxQueryDepth: Option[Int] = None, queryReducers: List[QueryReducer[Ctx, _]] = Nil, - inferVariableDefinitions: Boolean = true + inferVariableDefinitions: Boolean = true, + errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext, marshaller: SymmetricMarshaller[T], @@ -100,7 +101,7 @@ object BatchExecutor { inferVariableDefinitions, exceptionHandler)) .flatMap { case res @ (updatedDocument, _) => - val violations = queryValidator.validateQuery(schema, updatedDocument) + val violations = queryValidator.validateQuery(schema, updatedDocument, errorsLimit) if (violations.nonEmpty) Failure(ValidationError(violations, exceptionHandler)) else Success(res) diff --git a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala index d738f362..7d711e68 100644 --- a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala +++ b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala @@ -61,8 +61,9 @@ class ResolverBasedAstSchemaBuilder[Ctx](val resolvers: Seq[AstSchemaResolver[Ct def validateSchema( schema: ast.Document, - validator: QueryValidator = ResolverBasedAstSchemaBuilder.validator): Vector[Violation] = - allowKnownDynamicDirectives(validator.validateQuery(validationSchema, schema)) + validator: QueryValidator = ResolverBasedAstSchemaBuilder.validator, + errorsLimit: Option[Int] = None): Vector[Violation] = + allowKnownDynamicDirectives(validator.validateQuery(validationSchema, schema, errorsLimit)) def validateSchemaWithException( schema: ast.Document, diff --git a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala index 1493d6ef..b3ff0ca6 100644 --- a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala +++ b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala @@ -11,7 +11,10 @@ import scala.collection.mutable.{ListBuffer, Map => MutableMap, Set => MutableSe import scala.reflect.{ClassTag, classTag} trait QueryValidator { - def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] + def validateQuery( + schema: Schema[_, _], + queryAst: ast.Document, + errorsLimit: Option[Int]): Vector[Violation] } object QueryValidator { @@ -45,25 +48,24 @@ object QueryValidator { new SingleFieldSubscriptions ) - @deprecated("use ruleBased setting 'errorsLimit' instead", "4.0.1") def ruleBased(rules: List[ValidationRule]): RuleBasedQueryValidator = - RuleBasedQueryValidator(rules) - def ruleBased(rules: List[ValidationRule], errorsLimit: Option[Int]): RuleBasedQueryValidator = - new RuleBasedQueryValidator(rules, errorsLimit) + new RuleBasedQueryValidator(rules) val empty: QueryValidator = new QueryValidator { - def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] = - Vector.empty + def validateQuery( + schema: Schema[_, _], + queryAst: ast.Document, + errorsLimit: Option[Int]): Vector[Violation] = Vector.empty } - val default: RuleBasedQueryValidator = ruleBased(allRules, errorsLimit = Some(10)) + val default: RuleBasedQueryValidator = ruleBased(allRules) } -class RuleBasedQueryValidator( - rules: List[ValidationRule], - errorsLimit: Option[Int] -) extends QueryValidator { - def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] = { +class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidator { + def validateQuery( + schema: Schema[_, _], + queryAst: ast.Document, + errorsLimit: Option[Int]): Vector[Violation] = { val ctx = new ValidationContext( schema, queryAst, @@ -152,15 +154,10 @@ class RuleBasedQueryValidator( val cls = classTag[T].runtimeClass val newRules = rules.filterNot(r => cls.isAssignableFrom(r.getClass)) - RuleBasedQueryValidator(newRules) + new RuleBasedQueryValidator(newRules) } } -object RuleBasedQueryValidator { - def apply(rules: List[ValidationRule]): RuleBasedQueryValidator = - new RuleBasedQueryValidator(rules, None) -} - class ValidationContext( val schema: Schema[_, _], val doc: ast.Document, diff --git a/modules/core/src/test/scala/sangria/execution/VariablesSpec.scala b/modules/core/src/test/scala/sangria/execution/VariablesSpec.scala index 4ecc8f9a..9c18471d 100644 --- a/modules/core/src/test/scala/sangria/execution/VariablesSpec.scala +++ b/modules/core/src/test/scala/sangria/execution/VariablesSpec.scala @@ -922,5 +922,54 @@ class VariablesSpec extends AnyWordSpec with Matchers with GraphQlSupport { validateQuery = false ) } + + "Execute: Limits number of errors" should { + val query = + """ + query noErrorLimits($input1: [String!]!, $input2: [String!]!, $input3: [String!]!) { + result1: nnListNN(input: $input1) + result2: nnListNN(input: $input2) + result3: nnListNN(input: $input3) + } + """ + val variables = + """{"input1": ["A",null,"B",null], "input2": [null,"C",null], "input3": [null,null,"D"]}""".parseJson + + "return all errors if no limit is provided" in checkContainsErrors( + (), + query, + None, + List( + """Variable '$input1' expected value of type '[String!]!' but got: ["A",null,"B",null].""" -> + List(Pos(2, 31)), + """Variable '$input1' expected value of type '[String!]!' but got: ["A",null,"B",null].""" -> + List(Pos(2, 31)), + """Variable '$input2' expected value of type '[String!]!' but got: [null,"C",null].""" -> + List(Pos(2, 52)), + """Variable '$input2' expected value of type '[String!]!' but got: [null,"C",null].""" -> + List(Pos(2, 52)), + """Variable '$input3' expected value of type '[String!]!' but got: [null,null,"D"].""" -> + List(Pos(2, 73)), + """Variable '$input3' expected value of type '[String!]!' but got: [null,null,"D"].""" -> + List(Pos(2, 73)) + ), + variables + ) + "return errors up to the limit" in checkContainsErrors( + (), + query, + None, + List( + """Variable '$input1' expected value of type '[String!]!' but got: ["A",null,"B",null].""" -> + List(Pos(2, 31)), + """Variable '$input1' expected value of type '[String!]!' but got: ["A",null,"B",null].""" -> + List(Pos(2, 31)), + """Variable '$input2' expected value of type '[String!]!' but got: [null,"C",null].""" -> + List(Pos(2, 52)) + ), + variables, + errorsLimit = Some(3) + ) + } } } diff --git a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala index e617fc57..203db9ce 100644 --- a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala +++ b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala @@ -30,7 +30,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) } "Notes that non-existent fields are invalid" in { @@ -42,7 +42,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should have size 1 + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 } "Requires fields on objects" in { @@ -52,7 +52,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should have size 1 + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 } "Disallows fields on scalars" in { @@ -66,7 +66,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should have size 1 + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 } "Disallows object fields on interfaces" in { @@ -79,7 +79,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should have size 1 + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 } "Allows object fields in fragments" in { @@ -96,7 +96,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) } "Allows object fields in inline fragments" in { @@ -111,7 +111,7 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) } } } diff --git a/modules/core/src/test/scala/sangria/util/CatsSupport.scala b/modules/core/src/test/scala/sangria/util/CatsSupport.scala index e4ff2f1b..60d56f92 100644 --- a/modules/core/src/test/scala/sangria/util/CatsSupport.scala +++ b/modules/core/src/test/scala/sangria/util/CatsSupport.scala @@ -220,8 +220,8 @@ object CatsScenarioExecutor extends FutureResultSupport { case Validate(rules) => ValidationResult( - RuleBasedQueryValidator(rules.toList) - .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get)) + new RuleBasedQueryValidator(rules.toList) + .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get, None)) case Execute(validate, value, vars, op) => val validator = if (validate) QueryValidator.default else QueryValidator.empty diff --git a/modules/core/src/test/scala/sangria/util/GraphQlSupport.scala b/modules/core/src/test/scala/sangria/util/GraphQlSupport.scala index bc152d5b..89452c88 100644 --- a/modules/core/src/test/scala/sangria/util/GraphQlSupport.scala +++ b/modules/core/src/test/scala/sangria/util/GraphQlSupport.scala @@ -22,7 +22,8 @@ object SimpleGraphQlSupport extends FutureResultSupport with Matchers { args: A, userContext: Any = (), resolver: DeferredResolver[Any] = DeferredResolver.empty, - validateQuery: Boolean = true) = { + validateQuery: Boolean = true, + errorsLimit: Option[Int] = None) = { val Success(doc) = QueryParser.parse(query) val exceptionHandler = ExceptionHandler { case (m, e) => @@ -38,7 +39,8 @@ object SimpleGraphQlSupport extends FutureResultSupport with Matchers { variables = args, exceptionHandler = exceptionHandler, queryValidator = if (validateQuery) QueryValidator.default else QueryValidator.empty, - deferredResolver = resolver + deferredResolver = resolver, + errorsLimit = errorsLimit ) .awaitAndRecoverQueryAnalysisScala } @@ -99,7 +101,8 @@ object SimpleGraphQlSupport extends FutureResultSupport with Matchers { args: JsValue = JsObject.empty, userContext: Any = (), resolver: DeferredResolver[_] = DeferredResolver.empty, - validateQuery: Boolean = true + validateQuery: Boolean = true, + errorsLimit: Option[Int] = None ): Unit = { val result = executeTestQuery( schema, @@ -108,7 +111,9 @@ object SimpleGraphQlSupport extends FutureResultSupport with Matchers { args, validateQuery = validateQuery, userContext = userContext, - resolver = resolver.asInstanceOf[DeferredResolver[Any]]).asInstanceOf[Map[String, Any]] + resolver = resolver.asInstanceOf[DeferredResolver[Any]], + errorsLimit = errorsLimit + ).asInstanceOf[Map[String, Any]] result.get("data") should be(expectedData) @@ -257,7 +262,8 @@ trait GraphQlSupport extends FutureResultSupport with Matchers { expectedData: Option[Map[String, Any]], expectedErrorStrings: Seq[(String, Seq[Pos])], args: JsValue = JsObject.empty, - validateQuery: Boolean = true): Unit = + validateQuery: Boolean = true, + errorsLimit: Option[Int] = None): Unit = SimpleGraphQlSupport.checkContainsErrors( schema, data, @@ -265,7 +271,8 @@ trait GraphQlSupport extends FutureResultSupport with Matchers { expectedData, expectedErrorStrings, args = args, - validateQuery = validateQuery) + validateQuery = validateQuery, + errorsLimit = errorsLimit) } case class Pos(line: Int, col: Int) diff --git a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala index dfd977bf..7de201fd 100644 --- a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala +++ b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala @@ -351,7 +351,7 @@ trait ValidationSupport extends Matchers { expectedErrors: Seq[(String, Seq[Pos])]) = { val Success(doc) = QueryParser.parse(query) - assertViolations(validator(rules).validateQuery(s, doc), expectedErrors: _*) + assertViolations(validator(rules).validateQuery(s, doc, None), expectedErrors: _*) } def expectInputInvalid( @@ -367,7 +367,7 @@ trait ValidationSupport extends Matchers { def expectValid(s: Schema[_, _], rules: List[ValidationRule], query: String) = { val Success(doc) = QueryParser.parse(query) - val errors = validator(rules).validateQuery(s, doc) + val errors = validator(rules).validateQuery(s, doc, None) withClue(renderViolations(errors)) { errors should have size 0 @@ -432,7 +432,7 @@ trait ValidationSupport extends Matchers { violationCheck: Violation => Unit): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest) + val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) violations shouldNot be(empty) violations.size shouldBe 1 violationCheck(violations.head) @@ -451,9 +451,9 @@ trait ValidationSupport extends Matchers { v: Option[ValidationRule]): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest) + val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) violations shouldBe empty } - def validator(rules: List[ValidationRule]) = RuleBasedQueryValidator(rules) + def validator(rules: List[ValidationRule]) = new RuleBasedQueryValidator(rules) } diff --git a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala index 7c954052..d64e81ef 100644 --- a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala +++ b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala @@ -8,7 +8,7 @@ import scala.util.Success class QueryValidatorSpec extends AnyWordSpec { "QueryValidator" when { - val rules = QueryValidator.allRules + val validator = new RuleBasedQueryValidator(QueryValidator.allRules) "testing RuleBasedQueryValidator" should { val TestInputType = InputObjectType( @@ -41,20 +41,17 @@ class QueryValidatorSpec extends AnyWordSpec { """ "not limit number of errors returned if the limit is not provided" in { - val validator = RuleBasedQueryValidator(rules) - val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc) + val result = validator.validateQuery(schema, doc, None) // 10 errors are expected because there are 5 input objects in the list with 2 missing fields each assertResult(10)(result.length) } "limit number of errors returned if the limit is provided" in { val errorsLimit = 5 - val validator = new RuleBasedQueryValidator(rules, Some(errorsLimit)) val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc) + val result = validator.validateQuery(schema, doc, Some(errorsLimit)) assertResult(errorsLimit)(result.length) }