From 43733dd48fa835edc4d562878bbd0825a11a9b16 Mon Sep 17 00:00:00 2001 From: Dima Pasieka Date: Thu, 29 Jun 2023 17:42:06 -0700 Subject: [PATCH] Limit number of errors returned (when variables are used) --- .../scala/sangria/execution/Executor.scala | 11 +- .../execution/InputDocumentMaterializer.scala | 4 +- .../execution/QueryReducerExecutor.scala | 4 +- .../execution/ValueCoercionHelper.scala | 161 ++++++++++-------- .../sangria/execution/ValueCollector.scala | 14 +- .../sangria/validation/QueryValidator.scala | 6 +- .../execution/ValueCoercionHelperSpec.scala | 4 +- 7 files changed, 119 insertions(+), 85 deletions(-) diff --git a/modules/core/src/main/scala/sangria/execution/Executor.scala b/modules/core/src/main/scala/sangria/execution/Executor.scala index e1bf053b..3b0555bc 100644 --- a/modules/core/src/main/scala/sangria/execution/Executor.scala +++ b/modules/core/src/main/scala/sangria/execution/Executor.scala @@ -4,9 +4,10 @@ import sangria.ast import sangria.ast.SourceMapper import sangria.marshalling.{InputUnmarshaller, ResultMarshaller} import sangria.schema._ -import sangria.validation.QueryValidator +import sangria.validation.{QueryValidator, RuleBasedQueryValidator} import InputUnmarshaller.emptyMapVars import sangria.execution.deferred.DeferredResolver + import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal import scala.util.{Failure, Success, Try} @@ -43,7 +44,9 @@ case class Executor[Ctx, Root]( userContext, exceptionHandler, scalarMiddleware, - false)(um) + false, + errorsLimit = queryValidator.errorsLimit + )(um) val executionResult = for { operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) @@ -155,7 +158,9 @@ case class Executor[Ctx, Root]( userContext, exceptionHandler, scalarMiddleware, - false)(um) + false, + errorsLimit = queryValidator.errorsLimit + )(um) val executionResult = for { operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) diff --git a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala index 1e5e1fd2..2b70937e 100644 --- a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala +++ b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala @@ -26,7 +26,9 @@ case class InputDocumentMaterializer[Vars]( (), ExceptionHandler.empty, None, - false)(iu) + false, + errorsLimit = None + )(iu) val violations = QueryValidator.default.validateInputDocument(schema, document, inputType) diff --git a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala index c086c064..236ceabf 100644 --- a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala @@ -35,7 +35,9 @@ object QueryReducerExecutor { userContext, exceptionHandler, scalarMiddleware, - true)(InputUnmarshaller.scalaInputUnmarshaller[Any @@ ScalaInput]) + true, + errorsLimit = queryValidator.errorsLimit + )(InputUnmarshaller.scalaInputUnmarshaller[Any @@ ScalaInput]) val executionResult = for { operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) diff --git a/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala b/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala index 289ece6e..9b391ccc 100644 --- a/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala +++ b/modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala @@ -10,11 +10,13 @@ import sangria.util.Cache import sangria.validation._ import scala.collection.immutable.VectorBuilder +import scala.collection.mutable.ListBuffer class ValueCoercionHelper[Ctx]( sourceMapper: Option[SourceMapper] = None, deprecationTracker: DeprecationTracker = DeprecationTracker.empty, - userContext: Option[Ctx] = None) { + userContext: Option[Ctx] = None, + errorsLimit: Option[Int] = None) { import ValueCoercionHelper.defaultValueMapFn private def resolveListValue( @@ -580,98 +582,108 @@ class ValueCoercionHelper[Ctx]( nodeLocation.toList ++ firstValue.toList } - def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit - um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match { - case (OptionInputType(ofType), Some(value)) if um.isDefined(value) => - isValidValue(ofType, Some(value)) - case (OptionInputType(_), _) => Vector.empty - case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil)) - - case (ListInputType(ofType), Some(values)) if um.isListNode(values) => - um.getListValue(values) - .toVector - .flatMap(v => + def isValidValue[In](tpe: InputType[_], input: Option[In], errors: ListBuffer[Violation])(implicit + um: InputUnmarshaller[In]): Vector[Violation] = { + // TODO: NEED TO ADD VALUES TO `errors` to make it work properly + + if (errorsLimit.exists(_ <= errors.length)) Vector.empty + else + (tpe, input) match { + case (OptionInputType(ofType), Some(value)) if um.isDefined(value) => + isValidValue(ofType, Some(value), errors) + case (OptionInputType(_), _) => Vector.empty + case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil)) + + case (ListInputType(ofType), Some(values)) if um.isListNode(values) => + um.getListValue(values) + .toVector + .flatMap(v => + isValidValue( + ofType, + v match { + case opt: Option[In @unchecked] => opt + case other => Option(other) + }, + errors).map(ListValueViolation(0, _, sourceMapper, Nil))) + + case (ListInputType(ofType), Some(value)) if um.isDefined(value) => isValidValue( ofType, - v match { + value match { case opt: Option[In @unchecked] => opt case other => Option(other) - }).map(ListValueViolation(0, _, sourceMapper, Nil))) + }, + errors).map(ListValueViolation(0, _, sourceMapper, Nil)) + + case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) => + val unknownFields = um.getMapKeys(valueMap).toVector.collect { + case f if !objTpe.fieldsByName.contains(f) => + UnknownInputObjectFieldViolation( + SchemaRenderer.renderTypeName(objTpe, true), + f, + sourceMapper, + Nil) + } - case (ListInputType(ofType), Some(value)) if um.isDefined(value) => - isValidValue( - ofType, - value match { - case opt: Option[In @unchecked] => opt - case other => Option(other) - }).map(ListValueViolation(0, _, sourceMapper, Nil)) - - case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) => - val unknownFields = um.getMapKeys(valueMap).toVector.collect { - case f if !objTpe.fieldsByName.contains(f) => - UnknownInputObjectFieldViolation( - SchemaRenderer.renderTypeName(objTpe, true), - f, - sourceMapper, - Nil) - } + val fieldViolations = + objTpe.fields.toVector.flatMap(f => + isValidValue(f.fieldType, um.getMapValue(valueMap, f.name), errors) + .map(MapValueViolation(f.name, _, sourceMapper, Nil))) - val fieldViolations = - objTpe.fields.toVector.flatMap(f => - isValidValue(f.fieldType, um.getMapValue(valueMap, f.name)) - .map(MapValueViolation(f.name, _, sourceMapper, Nil))) + fieldViolations ++ unknownFields - fieldViolations ++ unknownFields + case (objTpe: InputObjectType[_], _) => + Vector( + InputObjectIsOfWrongTypeMissingViolation( + SchemaRenderer.renderTypeName(objTpe, true), + sourceMapper, + Nil)) - case (objTpe: InputObjectType[_], _) => - Vector( - InputObjectIsOfWrongTypeMissingViolation( - SchemaRenderer.renderTypeName(objTpe, true), - sourceMapper, - Nil)) + case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => scalar.coerceInput(node) + case other => scalar.coerceUserInput(other) + } - case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => scalar.coerceInput(node) - case other => scalar.coerceUserInput(other) - } + coerced match { + case Left(violation) => Vector(violation) + case _ => Vector.empty + } - coerced match { - case Left(violation) => Vector(violation) - case _ => Vector.empty - } + case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => scalar.aliasFor.coerceInput(node) + case other => scalar.aliasFor.coerceUserInput(other) + } - case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => scalar.aliasFor.coerceInput(node) - case other => scalar.aliasFor.coerceUserInput(other) - } + coerced match { + case Left(violation) => Vector(violation) + case Right(v) => + scalar.fromScalar(v) match { + case Left(violation) => Vector(violation) + case _ => Vector.empty + } + } + + case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) => + val coerced = um.getScalarValue(value) match { + case node: ast.Value => enumT.coerceInput(node) + case other => enumT.coerceUserInput(other) + } - coerced match { - case Left(violation) => Vector(violation) - case Right(v) => - scalar.fromScalar(v) match { + coerced match { case Left(violation) => Vector(violation) case _ => Vector.empty } - } - case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) => - val coerced = um.getScalarValue(value) match { - case node: ast.Value => enumT.coerceInput(node) - case other => enumT.coerceUserInput(other) - } + case (enumT: EnumType[_], Some(value)) => + Vector(EnumCoercionViolation) - coerced match { - case Left(violation) => Vector(violation) - case _ => Vector.empty + case _ => + Vector(GenericInvalidValueViolation(sourceMapper, Nil)) } - case (enumT: EnumType[_], Some(value)) => - Vector(EnumCoercionViolation) - case _ => - Vector(GenericInvalidValueViolation(sourceMapper, Nil)) } def getVariableValue[In]( @@ -680,7 +692,8 @@ class ValueCoercionHelper[Ctx]( input: Option[In], fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = { - val violations = isValidValue(tpe, input) + val errors = ListBuffer[Violation]() + val violations = isValidValue(tpe, input, errors) if (violations.isEmpty) { val fieldPath = s"$$${definition.name}" :: Nil diff --git a/modules/core/src/main/scala/sangria/execution/ValueCollector.scala b/modules/core/src/main/scala/sangria/execution/ValueCollector.scala index 9aadc755..499933a4 100644 --- a/modules/core/src/main/scala/sangria/execution/ValueCollector.scala +++ b/modules/core/src/main/scala/sangria/execution/ValueCollector.scala @@ -19,9 +19,10 @@ class ValueCollector[Ctx, Input]( userContext: Ctx, exceptionHandler: ExceptionHandler, fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]], - ignoreErrors: Boolean)(implicit um: InputUnmarshaller[Input]) { + ignoreErrors: Boolean, + errorsLimit: Option[Int])(implicit um: InputUnmarshaller[Input]) { val coercionHelper = - new ValueCoercionHelper[Ctx](sourceMapper, deprecationTracker, Some(userContext)) + new ValueCoercionHelper[Ctx](sourceMapper, deprecationTracker, Some(userContext), errorsLimit) private val argumentCache = Cache.empty[(ExecutionPath.PathCacheKeyReversed, Vector[ast.Argument]), Try[Args]] @@ -63,7 +64,14 @@ class ValueCollector[Ctx, Input]( } } - val (errors, values) = res.partition(_._2.isLeft) + val (allErrors, values) = res.partition(_._2.isLeft) + + // additionally reduce the number of errors + // although we limit the number of errors per input, it's still possible to exceed the limit if errors are combined. + val errors = errorsLimit match { + case Some(limit) => allErrors.take(limit) + case _ => allErrors + } if (errors.nonEmpty) Failure( diff --git a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala index 1493d6ef..a5ba0536 100644 --- a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala +++ b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala @@ -12,6 +12,7 @@ import scala.reflect.{ClassTag, classTag} trait QueryValidator { def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] + def errorsLimit: Option[Int] } object QueryValidator { @@ -54,14 +55,15 @@ object QueryValidator { val empty: QueryValidator = new QueryValidator { def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] = Vector.empty + override val errorsLimit: Option[Int] = None } - val default: RuleBasedQueryValidator = ruleBased(allRules, errorsLimit = Some(10)) + val default: RuleBasedQueryValidator = ruleBased(allRules, errorsLimit = Some(1)) } class RuleBasedQueryValidator( rules: List[ValidationRule], - errorsLimit: Option[Int] + override val errorsLimit: Option[Int] ) extends QueryValidator { def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] = { val ctx = new ValidationContext( diff --git a/modules/core/src/test/scala/sangria/execution/ValueCoercionHelperSpec.scala b/modules/core/src/test/scala/sangria/execution/ValueCoercionHelperSpec.scala index 3a465ea0..3772fd85 100644 --- a/modules/core/src/test/scala/sangria/execution/ValueCoercionHelperSpec.scala +++ b/modules/core/src/test/scala/sangria/execution/ValueCoercionHelperSpec.scala @@ -213,7 +213,9 @@ class ValueCoercionHelperSpec extends AnyWordSpec with Matchers { (), ExceptionHandler.empty, None, - true) + true, + errorsLimit = None + ) val variables = valueCollector .getVariableValues( QueryParser