Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve validators by introducing a dedicated ValidationResult type and simplifying ValidationError #2108

Merged
merged 3 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 41 additions & 79 deletions core/src/main/scala/sttp/tapir/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,94 +77,52 @@ object Validator extends ValidatorMacros {
def enumeration[T](possibleValues: List[T], encode: EncodeToRaw[T], name: Option[SName] = None): Validator.Enumeration[T] =
Enumeration(possibleValues, Some(encode), name)

/** Create a custom validator
* @param doValidate
* Validation function
/** Create a custom validator.
* @param validationLogic
* The logic of the validator
* @param showMessage
* Custom message
* Description of the validator used when invoking [[Validator.show]].
*/
def custom[T](doValidate: T => List[ValidationError[_]], showMessage: Option[String] = None): Validator[T] =
Custom(doValidate, showMessage)
def custom[T](validationLogic: T => ValidationResult, showMessage: Option[String] = None): Validator[T] =
Custom(validationLogic, showMessage)

// ---------- PRIMITIVE ----------
sealed trait Primitive[T] extends Validator[T]
case class Min[T](value: T, exclusive: Boolean)(implicit val valueIsNumeric: Numeric[T]) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (implicitly[Numeric[T]].gt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value))) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
sealed trait Primitive[T] extends Validator[T] {
def doValidate(t: T): ValidationResult
override def apply(t: T): List[ValidationError[T]] = doValidate(t) match {
case ValidationResult.Valid => Nil
case ValidationResult.Invalid(customMessage) => List(ValidationError(this, t, Nil, customMessage))
}
}
case class Min[T](value: T, exclusive: Boolean)(implicit val valueIsNumeric: Numeric[T]) extends Primitive[T] {
override def doValidate(t: T): ValidationResult =
ValidationResult.validWhen(implicitly[Numeric[T]].gt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value)))
}
case class Max[T](value: T, exclusive: Boolean)(implicit val valueIsNumeric: Numeric[T]) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (implicitly[Numeric[T]].lt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value))) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: T): ValidationResult =
ValidationResult.validWhen(implicitly[Numeric[T]].lt(t, value) || (!exclusive && implicitly[Numeric[T]].equiv(t, value)))
}
case class Pattern[T <: String](value: String) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (t.matches(value)) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: T): ValidationResult = ValidationResult.validWhen(t.matches(value))
}
case class MinLength[T <: String](value: Int) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (t.size >= value) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: T): ValidationResult = ValidationResult.validWhen(t.size >= value)
}
case class MaxLength[T <: String](value: Int) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (t.size <= value) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: T): ValidationResult = ValidationResult.validWhen(t.size <= value)
}
case class MinSize[T, C[_] <: Iterable[_]](value: Int) extends Primitive[C[T]] {
override def apply(t: C[T]): List[ValidationError[_]] = {
if (t.size >= value) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: C[T]): ValidationResult = ValidationResult.validWhen(t.size >= value)
}
case class MaxSize[T, C[_] <: Iterable[_]](value: Int) extends Primitive[C[T]] {
override def apply(t: C[T]): List[ValidationError[_]] = {
if (t.size <= value) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: C[T]): ValidationResult = ValidationResult.validWhen(t.size <= value)
}
case class Custom[T](doValidate: T => List[ValidationError[_]], showMessage: Option[String] = None) extends Validator[T] {
override def apply(t: T): List[ValidationError[_]] = {
doValidate(t)
}
case class Custom[T](validationLogic: T => ValidationResult, showMessage: Option[String] = None) extends Primitive[T] {
override def doValidate(t: T): ValidationResult = validationLogic(t)
}

case class Enumeration[T](possibleValues: List[T], encode: Option[EncodeToRaw[T]], name: Option[SName]) extends Primitive[T] {
override def apply(t: T): List[ValidationError[_]] = {
if (possibleValues.contains(t)) {
List.empty
} else {
List(ValidationError.Primitive(this, t))
}
}
override def doValidate(t: T): ValidationResult = ValidationResult.validWhen(possibleValues.contains(t))

/** Specify how values of this type can be encoded to a raw value (typically a [[String]]). This encoding will be used when generating
* documentation.
Expand Down Expand Up @@ -238,18 +196,22 @@ object Validator extends ValidatorMacros {
}
}

sealed trait ValidationError[T] {
def prependPath(f: FieldName): ValidationError[T]
def invalidValue: T
def path: List[FieldName]
}

object ValidationError {
case class Primitive[T](validator: Validator.Primitive[T], invalidValue: T, path: List[FieldName] = Nil) extends ValidationError[T] {
override def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
sealed trait ValidationResult
object ValidationResult {
case object Valid extends ValidationResult
case class Invalid(customMessage: Option[String] = None) extends ValidationResult
object Invalid {
def apply(customMessage: String): Invalid = Invalid(Some(customMessage))
}

case class Custom[T](invalidValue: T, message: String, path: List[FieldName] = Nil) extends ValidationError[T] {
override def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
}
def validWhen(condition: Boolean): ValidationResult = if (condition) Valid else Invalid()
}

case class ValidationError[T](
validator: Validator.Primitive[T],
invalidValue: T,
path: List[FieldName] = Nil,
customMessage: Option[String] = None
) {
def prependPath(f: FieldName): ValidationError[T] = copy(path = f :: path)
}
40 changes: 18 additions & 22 deletions core/src/test/scala/sttp/tapir/SchemaApplyValidationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {
implicit val schemaForInt: Schema[Int] = Schema.schemaForInt.validate(Validator.min(10))
val schema = implicitly[Schema[Map[String, Int]]]

schema.applyValidation(Map("key" -> 0)).map(noPath(_)) shouldBe List(ValidationError.Primitive(Validator.min(10), 0))
schema.applyValidation(Map("key" -> 0)).map(noPath(_)) shouldBe List(ValidationError(Validator.min(10), 0))
schema.applyValidation(Map("key" -> 12)) shouldBe empty
}

Expand All @@ -24,7 +24,7 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {

schema.applyValidation(None) shouldBe empty
schema.applyValidation(Some(12)) shouldBe empty
schema.applyValidation(Some(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
schema.applyValidation(Some(5)) shouldBe List(ValidationError(Validator.min(10), 5))
}

it should "validate iterable" in {
Expand All @@ -33,7 +33,7 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {

schema.applyValidation(List.empty[Int]) shouldBe empty
schema.applyValidation(List(11)) shouldBe empty
schema.applyValidation(List(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
schema.applyValidation(List(5)) shouldBe List(ValidationError(Validator.min(10), 5))
}

it should "validate array" in {
Expand All @@ -42,7 +42,7 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {

schema.applyValidation(Array.empty[Int]) shouldBe empty
schema.applyValidation(Array(11)) shouldBe empty
schema.applyValidation(Array(5)) shouldBe List(ValidationError.Primitive(Validator.min(10), 5))
schema.applyValidation(Array(5)) shouldBe List(ValidationError(Validator.min(10), 5))
}

it should "skip collection validation for array if element validator is passing" in {
Expand Down Expand Up @@ -93,13 +93,13 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {
implicit val ageSchema: Schema[Int] = Schema.schemaForInt.validate(Validator.min(18))
val schema = Schema.derived[Person]
schema.applyValidation(Person("notImportantButOld", 21)).map(noPath(_)) shouldBe List(
ValidationError.Primitive(Validator.pattern("^[A-Z].*"), "notImportantButOld")
ValidationError(Validator.pattern("^[A-Z].*"), "notImportantButOld")
)
schema.applyValidation(Person("notImportantAndYoung", 15)).map(noPath(_)) shouldBe List(
ValidationError.Primitive(Validator.pattern("^[A-Z].*"), "notImportantAndYoung"),
ValidationError.Primitive(Validator.min(18), 15)
ValidationError(Validator.pattern("^[A-Z].*"), "notImportantAndYoung"),
ValidationError(Validator.min(18), 15)
)
schema.applyValidation(Person("ImportantButYoung", 15)).map(noPath(_)) shouldBe List(ValidationError.Primitive(Validator.min(18), 15))
schema.applyValidation(Person("ImportantButYoung", 15)).map(noPath(_)) shouldBe List(ValidationError(Validator.min(18), 15))
schema.applyValidation(Person("ImportantAndOld", 21)) shouldBe empty
}

Expand All @@ -116,15 +116,15 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {

schema.applyValidation(RecursiveName("x", None)) shouldBe Nil
schema.applyValidation(RecursiveName("", None)) shouldBe List(
ValidationError.Primitive(Validator.minLength(1), "", List(FieldName("name")))
ValidationError(Validator.minLength(1), "", List(FieldName("name")))
)
schema.applyValidation(RecursiveName("x", Some(Vector(RecursiveName("x", None))))) shouldBe Nil
schema.applyValidation(RecursiveName("x", Some(Vector(RecursiveName("", None))))) shouldBe List(
ValidationError.Primitive(Validator.minLength(1), "", List(FieldName("subNames"), FieldName("name")))
ValidationError(Validator.minLength(1), "", List(FieldName("subNames"), FieldName("name")))
)
schema.applyValidation(RecursiveName("x", Some(Vector(RecursiveName("x", Some(Vector(RecursiveName("x", None)))))))) shouldBe Nil
schema.applyValidation(RecursiveName("x", Some(Vector(RecursiveName("x", Some(Vector(RecursiveName("", None)))))))) shouldBe List(
ValidationError.Primitive(Validator.minLength(1), "", List(FieldName("subNames"), FieldName("subNames"), FieldName("name")))
ValidationError(Validator.minLength(1), "", List(FieldName("subNames"), FieldName("subNames"), FieldName("name")))
)
}

Expand All @@ -143,8 +143,8 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {
schema.applyValidation(Left(10)) shouldBe Nil
schema.applyValidation(Right("x")) shouldBe Nil

schema.applyValidation(Left(0)) shouldBe List(ValidationError.Primitive(Validator.min(1), 0))
schema.applyValidation(Right("")) shouldBe List(ValidationError.Primitive(Validator.minLength(1), ""))
schema.applyValidation(Left(0)) shouldBe List(ValidationError(Validator.min(1), 0))
schema.applyValidation(Right("")) shouldBe List(ValidationError(Validator.minLength(1), ""))
}

it should "validate mapped either" in {
Expand All @@ -159,14 +159,14 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {
schema.applyValidation(EitherWrapper(Left(10))) shouldBe Nil
schema.applyValidation(EitherWrapper(Right("x"))) shouldBe Nil

schema.applyValidation(EitherWrapper(Left(0))) shouldBe List(ValidationError.Primitive(Validator.min(1), 0))
schema.applyValidation(EitherWrapper(Right(""))) shouldBe List(ValidationError.Primitive(Validator.minLength(1), ""))
schema.applyValidation(EitherWrapper(Left(0))) shouldBe List(ValidationError(Validator.min(1), 0))
schema.applyValidation(EitherWrapper(Right(""))) shouldBe List(ValidationError(Validator.minLength(1), ""))
}

it should "validate oneOf object" in {
case class SomeObject(value: String)
object SomeObject {
implicit def someObjectSchema: Schema[SomeObject] = Schema.derived[SomeObject].validate(Validator.custom(_ => Nil))
implicit def someObjectSchema: Schema[SomeObject] = Schema.derived[SomeObject].validate(Validator.custom(_ => ValidationResult.Valid))
}

sealed trait Entity {
Expand All @@ -180,12 +180,8 @@ class SchemaApplyValidationTest extends AnyFlatSpec with Matchers {
override def kind: String = "person"
}

Entity.entitySchema.applyValidation(Person(SomeObject("1234")))
Entity.entitySchema.applyValidation(Person(SomeObject("1234"))) shouldBe Nil
}

private def noPath[T](v: ValidationError[T]): ValidationError[T] =
v match {
case p: ValidationError.Primitive[T] => p.copy(path = Nil)
case c: ValidationError.Custom[T] => c.copy(path = Nil)
}
private def noPath[T](v: ValidationError[T]): ValidationError[T] = v.copy(path = Nil)
}
Loading