Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT / WIP: Limit number of errors returned (when variables are used) #1022

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions modules/core/src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import sangria.ast
import sangria.ast.SourceMapper
import sangria.marshalling.{InputUnmarshaller, ResultMarshaller}
import sangria.schema._
import sangria.validation.QueryValidator
import sangria.validation.{QueryValidator, RuleBasedQueryValidator}
import InputUnmarshaller.emptyMapVars
import sangria.execution.deferred.DeferredResolver

import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -43,7 +44,9 @@ case class Executor[Ctx, Root](
userContext,
exceptionHandler,
scalarMiddleware,
false)(um)
false,
errorsLimit = queryValidator.errorsLimit
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to define errorsLimit only once (e.g. on the queryValidator) and re-use in other places

)(um)

val executionResult = for {
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
Expand Down Expand Up @@ -155,7 +158,9 @@ case class Executor[Ctx, Root](
userContext,
exceptionHandler,
scalarMiddleware,
false)(um)
false,
errorsLimit = queryValidator.errorsLimit
)(um)

val executionResult = for {
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ case class InputDocumentMaterializer[Vars](
(),
ExceptionHandler.empty,
None,
false)(iu)
false,
errorsLimit = None
)(iu)

val violations = QueryValidator.default.validateInputDocument(schema, document, inputType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ object QueryReducerExecutor {
userContext,
exceptionHandler,
scalarMiddleware,
true)(InputUnmarshaller.scalaInputUnmarshaller[Any @@ ScalaInput])
true,
errorsLimit = queryValidator.errorsLimit
)(InputUnmarshaller.scalaInputUnmarshaller[Any @@ ScalaInput])

val executionResult = for {
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import sangria.util.Cache
import sangria.validation._

import scala.collection.immutable.VectorBuilder
import scala.collection.mutable.ListBuffer

class ValueCoercionHelper[Ctx](
sourceMapper: Option[SourceMapper] = None,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
userContext: Option[Ctx] = None) {
userContext: Option[Ctx] = None,
errorsLimit: Option[Int] = None) {
import ValueCoercionHelper.defaultValueMapFn

private def resolveListValue(
Expand Down Expand Up @@ -580,98 +582,108 @@ class ValueCoercionHelper[Ctx](
nodeLocation.toList ++ firstValue.toList
}

def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
def isValidValue[In](tpe: InputType[_], input: Option[In], errors: ListBuffer[Violation])(implicit
um: InputUnmarshaller[In]): Vector[Violation] = {
// TODO: NEED TO ADD VALUES TO `errors` to make it work properly
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code doesn't work correctly because I don't mutate errors buffer, but this is one of the options how it can be implemented


if (errorsLimit.exists(_ <= errors.length)) Vector.empty
else
(tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(ofType, Some(value), errors)
case (OptionInputType(_), _) => Vector.empty
case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValue(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
},
errors).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(
ofType,
v match {
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))
},
errors).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)
}

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)
}
val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name), errors)
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))

val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))
fieldViolations ++ unknownFields

fieldViolations ++ unknownFields
case (objTpe: InputObjectType[_], _) =>
Vector(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))

case (objTpe: InputObjectType[_], _) =>
Vector(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))
case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}

case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}
coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}

case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}
coerced match {
case Left(violation) => Vector(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
}

case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}

coerced match {
case Left(violation) => Vector(violation)
case Right(v) =>
scalar.fromScalar(v) match {
coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
}

case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}
case (enumT: EnumType[_], Some(value)) =>
Vector(EnumCoercionViolation)

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
case _ =>
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
}

case (enumT: EnumType[_], Some(value)) =>
Vector(EnumCoercionViolation)

case _ =>
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
}

def getVariableValue[In](
Expand All @@ -680,7 +692,8 @@ class ValueCoercionHelper[Ctx](
input: Option[In],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit
um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input)
val errors = ListBuffer[Violation]()
val violations = isValidValue(tpe, input, errors)

if (violations.isEmpty) {
val fieldPath = s"$$${definition.name}" :: Nil
Expand Down
14 changes: 11 additions & 3 deletions modules/core/src/main/scala/sangria/execution/ValueCollector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ class ValueCollector[Ctx, Input](
userContext: Ctx,
exceptionHandler: ExceptionHandler,
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]],
ignoreErrors: Boolean)(implicit um: InputUnmarshaller[Input]) {
ignoreErrors: Boolean,
errorsLimit: Option[Int])(implicit um: InputUnmarshaller[Input]) {
val coercionHelper =
new ValueCoercionHelper[Ctx](sourceMapper, deprecationTracker, Some(userContext))
new ValueCoercionHelper[Ctx](sourceMapper, deprecationTracker, Some(userContext), errorsLimit)

private val argumentCache =
Cache.empty[(ExecutionPath.PathCacheKeyReversed, Vector[ast.Argument]), Try[Args]]
Expand Down Expand Up @@ -63,7 +64,14 @@ class ValueCollector[Ctx, Input](
}
}

val (errors, values) = res.partition(_._2.isLeft)
val (allErrors, values) = res.partition(_._2.isLeft)

// additionally reduce the number of errors
// although we limit the number of errors per input, it's still possible to exceed the limit if errors are combined.
val errors = errorsLimit match {
case Some(limit) => allErrors.take(limit)
case _ => allErrors
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like there are other places where violations could be returned, but I'm not sure if they should be updated as well.


if (errors.nonEmpty)
Failure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import scala.reflect.{ClassTag, classTag}

trait QueryValidator {
def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation]
def errorsLimit: Option[Int]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it public so it can be re-used?

}

object QueryValidator {
Expand Down Expand Up @@ -54,14 +55,15 @@ object QueryValidator {
val empty: QueryValidator = new QueryValidator {
def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] =
Vector.empty
override val errorsLimit: Option[Int] = None
}

val default: RuleBasedQueryValidator = ruleBased(allRules, errorsLimit = Some(10))
val default: RuleBasedQueryValidator = ruleBased(allRules, errorsLimit = Some(1))
}

class RuleBasedQueryValidator(
rules: List[ValidationRule],
errorsLimit: Option[Int]
override val errorsLimit: Option[Int]
) extends QueryValidator {
def validateQuery(schema: Schema[_, _], queryAst: ast.Document): Vector[Violation] = {
val ctx = new ValidationContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ class ValueCoercionHelperSpec extends AnyWordSpec with Matchers {
(),
ExceptionHandler.empty,
None,
true)
true,
errorsLimit = None
)
val variables = valueCollector
.getVariableValues(
QueryParser
Expand Down