From 2b39142b16e69fa9405d5c7401dc62a5776a9bf5 Mon Sep 17 00:00:00 2001 From: Onsi Fakhouri Date: Wed, 25 Oct 2023 10:51:49 -0600 Subject: [PATCH] MatchError can now take an optional func(error) bool + description --- docs/index.md | 15 +++++---- matchers.go | 37 ++++++++++++++++---- matchers/match_error_matcher.go | 25 ++++++++++++-- matchers/match_error_matcher_test.go | 50 ++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 15 deletions(-) diff --git a/docs/index.md b/docs/index.md index 8a5dd9ba5..5fefcae87 100644 --- a/docs/index.md +++ b/docs/index.md @@ -812,18 +812,19 @@ where `FUNCTION()` is a function call that returns an error-type as its *first o #### MatchError(expected interface{}) ```go -Ω(ACTUAL).Should(MatchError(EXPECTED)) +Ω(ACTUAL).Should(MatchError(EXPECTED, )) ``` succeeds if `ACTUAL` is a non-nil `error` that matches `EXPECTED`. `EXPECTED` must be one of the following: -- A string, in which case `ACTUAL.Error()` will be compared against `EXPECTED`. -- A matcher, in which case `ACTUAL.Error()` is tested against the matcher. -- An error, in which case any of the following is satisfied: - - `errors.Is(ACTUAL, EXPECTED)` returns `true` - - `ACTUAL` or any of the errors it wraps (directly or indirectly) equals `EXPECTED` in terms of `reflect.DeepEqual()`. +- A string, in which case the matcher asserts that `ACTUAL.Error() == EXPECTED` +- An error (i.e. anything satisfying Go's `error` interface). In which case the matcher: + - First checks if `errors.Is(ACTUAL, EXPECTED)` returns `true` + - If not, it checks if `ACTUAL` or any of the errors it wraps (directly or indirectly) equals `EXPECTED` via `reflect.DeepEqual()`. +- A matcher, in which case `ACTUAL.Error()` is tested against the matcher, for example `Expect(err).Should(MatchError(ContainSubstring("sprocket not found")))` will pass if `err.Error()` has the substring "sprocke tnot found" +- A function with signature `func(error) bool`. The matcher then passes if `f(ACTUAL)` returns `true`. If using a function in this way you are required to pass a `FUNCTION_ERROR_DESCRIPTION` argument to `MatchError` that describes the function. This description is used in the failure message. For example: `Expect(err).To(MatchError(os.IsNotExist, "IsNotExist))` -Any other type for `EXPECTED` is an error. It is also an error for `ACTUAL` to be nil. +Any other type for `EXPECTED` is an error. It is also an error for `ACTUAL` to be nil. Note that `FUNCTION_ERROR_DESCRIPTION` is a description of the error function, if used. This is required when passing a function but is ignored in all other cases. ### Working with Channels diff --git a/matchers.go b/matchers.go index 88f100432..cd3f431d2 100644 --- a/matchers.go +++ b/matchers.go @@ -88,19 +88,44 @@ func Succeed() types.GomegaMatcher { } // MatchError succeeds if actual is a non-nil error that matches the passed in -// string, error, or matcher. +// string, error, function, or matcher. // // These are valid use-cases: // -// Expect(err).Should(MatchError("an error")) //asserts that err.Error() == "an error" -// Expect(err).Should(MatchError(SomeError)) //asserts that err == SomeError (via reflect.DeepEqual) -// Expect(err).Should(MatchError(ContainSubstring("sprocket not found"))) // asserts that err.Error() contains substring "sprocket not found" +// When passed a string: +// +// Expect(err).To(MatchError("an error")) +// +// asserts that err.Error() == "an error" +// +// When passed an error: +// +// Expect(err).To(MatchError(SomeError)) +// +// First checks if errors.Is(err, SomeError). +// If that fails then it checks if reflect.DeepEqual(err, SomeError) repeatedly for err and any errors wrapped by err +// +// When passed a matcher: +// +// Expect(err).To(MatchError(ContainSubstring("sprocket not found"))) +// +// the matcher is passed err.Error(). In this case it asserts that err.Error() contains substring "sprocket not found" +// +// When passed a func(err) bool and a description: +// +// Expect(err).To(MatchError(os.IsNotExist, "IsNotExist")) +// +// the function is passed err and matches if the return value is true. The description is required to allow Gomega +// to print a useful error message. // // It is an error for err to be nil or an object that does not implement the // Error interface -func MatchError(expected interface{}) types.GomegaMatcher { +// +// The optional second argument is a description of the error function, if used. This is required when passing a function but is ignored in all other cases. +func MatchError(expected interface{}, functionErrorDescription ...any) types.GomegaMatcher { return &matchers.MatchErrorMatcher{ - Expected: expected, + Expected: expected, + FuncErrDescription: functionErrorDescription, } } diff --git a/matchers/match_error_matcher.go b/matchers/match_error_matcher.go index 827475ea5..c539dd389 100644 --- a/matchers/match_error_matcher.go +++ b/matchers/match_error_matcher.go @@ -9,10 +9,14 @@ import ( ) type MatchErrorMatcher struct { - Expected interface{} + Expected any + FuncErrDescription []any + isFunc bool } -func (matcher *MatchErrorMatcher) Match(actual interface{}) (success bool, err error) { +func (matcher *MatchErrorMatcher) Match(actual any) (success bool, err error) { + matcher.isFunc = false + if isNil(actual) { return false, fmt.Errorf("Expected an error, got nil") } @@ -42,6 +46,17 @@ func (matcher *MatchErrorMatcher) Match(actual interface{}) (success bool, err e return actualErr.Error() == expected, nil } + v := reflect.ValueOf(expected) + t := v.Type() + errorInterface := reflect.TypeOf((*error)(nil)).Elem() + if t.Kind() == reflect.Func && t.NumIn() == 1 && t.In(0).Implements(errorInterface) && t.NumOut() == 1 && t.Out(0).Kind() == reflect.Bool { + if len(matcher.FuncErrDescription) == 0 { + return false, fmt.Errorf("MatchError requires an additional description when passed a function") + } + matcher.isFunc = true + return v.Call([]reflect.Value{reflect.ValueOf(actualErr)})[0].Bool(), nil + } + var subMatcher omegaMatcher var hasSubMatcher bool if expected != nil { @@ -57,9 +72,15 @@ func (matcher *MatchErrorMatcher) Match(actual interface{}) (success bool, err e } func (matcher *MatchErrorMatcher) FailureMessage(actual interface{}) (message string) { + if matcher.isFunc { + return format.Message(actual, fmt.Sprintf("to match error function %s", matcher.FuncErrDescription[0])) + } return format.Message(actual, "to match error", matcher.Expected) } func (matcher *MatchErrorMatcher) NegatedFailureMessage(actual interface{}) (message string) { + if matcher.isFunc { + return format.Message(actual, fmt.Sprintf("not to match error function %s", matcher.FuncErrDescription[0])) + } return format.Message(actual, "not to match error", matcher.Expected) } diff --git a/matchers/match_error_matcher_test.go b/matchers/match_error_matcher_test.go index 57bd64f58..666ea5e89 100644 --- a/matchers/match_error_matcher_test.go +++ b/matchers/match_error_matcher_test.go @@ -85,6 +85,36 @@ var _ = Describe("MatchErrorMatcher", func() { }) }) + When("passed a function that takes error and returns bool", func() { + var IsFooError = func(err error) bool { + return err.Error() == "foo" + } + + It("requires an additional description", func() { + _, err := (&MatchErrorMatcher{ + Expected: IsFooError, + }).Match(errors.New("foo")) + Expect(err).Should(MatchError("MatchError requires an additional description when passed a function")) + }) + + It("matches iff the function returns true", func() { + Ω(errors.New("foo")).Should(MatchError(IsFooError, "FooError")) + Ω(errors.New("fooo")).ShouldNot(MatchError(IsFooError, "FooError")) + }) + + It("uses the error description to construct its message", func() { + failuresMessages := InterceptGomegaFailures(func() { + Ω(errors.New("fooo")).Should(MatchError(IsFooError, "FooError")) + }) + Ω(failuresMessages[0]).Should(ContainSubstring("fooo\n {s: \"fooo\"}\nto match error function FooError")) + + failuresMessages = InterceptGomegaFailures(func() { + Ω(errors.New("foo")).ShouldNot(MatchError(IsFooError, "FooError")) + }) + Ω(failuresMessages[0]).Should(ContainSubstring("foo\n {s: \"foo\"}\nnot to match error function FooError")) + }) + }) + It("should fail when passed anything else", func() { actualErr := errors.New("an error") _, err := (&MatchErrorMatcher{ @@ -96,6 +126,26 @@ var _ = Describe("MatchErrorMatcher", func() { Expected: 3, }).Match(actualErr) Expect(err).Should(HaveOccurred()) + + _, err = (&MatchErrorMatcher{ + Expected: func(e error) {}, + }).Match(actualErr) + Expect(err).Should(HaveOccurred()) + + _, err = (&MatchErrorMatcher{ + Expected: func() bool { return false }, + }).Match(actualErr) + Expect(err).Should(HaveOccurred()) + + _, err = (&MatchErrorMatcher{ + Expected: func() {}, + }).Match(actualErr) + Expect(err).Should(HaveOccurred()) + + _, err = (&MatchErrorMatcher{ + Expected: func(e error, a string) (bool, error) { return false, nil }, + }).Match(actualErr) + Expect(err).Should(HaveOccurred()) }) })