diff --git a/error.go b/error.go index d78ea58..3a828b2 100644 --- a/error.go +++ b/error.go @@ -142,6 +142,7 @@ package multierr // import "go.uber.org/multierr" import ( "bytes" + "errors" "fmt" "io" "strings" @@ -234,6 +235,17 @@ func (merr *multiError) Error() string { return result } +// Every compares every error in the given err against the given target error +// using [errors.Is], and returns true only if every comparison returned true. +func Every(err error, target error) bool { + for _, e := range extractErrors(err) { + if !errors.Is(e, target) { + return false + } + } + return true +} + func (merr *multiError) Format(f fmt.State, c rune) { if c == 'v' && f.Flag('+') { merr.writeMultiline(f) diff --git a/error_post_go120_test.go b/error_post_go120_test.go index ed35289..39873c2 100644 --- a/error_post_go120_test.go +++ b/error_post_go120_test.go @@ -40,3 +40,26 @@ func TestErrorsOnErrorsJoin(t *testing.T) { assert.Equal(t, err1, errs[0]) assert.Equal(t, err2, errs[1]) } + +func TestEveryWithErrorsJoin(t *testing.T) { + myError1 := errors.New("woeful misfortune") + myError2 := errors.New("worrisome travesty") + + t.Run("all match", func(t *testing.T) { + err := errors.Join(myError1, myError1, myError1) + + assert.True(t, errors.Is(err, myError1)) + assert.True(t, Every(err, myError1)) + assert.False(t, errors.Is(err, myError2)) + assert.False(t, Every(err, myError2)) + }) + + t.Run("one matches", func(t *testing.T) { + err := errors.Join(myError1, myError2) + + assert.True(t, errors.Is(err, myError1)) + assert.False(t, Every(err, myError1)) + assert.True(t, errors.Is(err, myError2)) + assert.False(t, Every(err, myError2)) + }) +} diff --git a/error_test.go b/error_test.go index f0bb17a..2e02d04 100644 --- a/error_test.go +++ b/error_test.go @@ -59,6 +59,67 @@ func newMultiErr(errors ...error) error { return &multiError{errors: errors} } +func TestEvery(t *testing.T) { + myError1 := errors.New("woeful misfortune") + myError2 := errors.New("worrisome travesty") + + for _, tt := range []struct { + desc string + giveErr error + giveTarget error + wantIs bool + wantEvery bool + }{ + { + desc: "all match", + giveErr: newMultiErr(myError1, myError1, myError1), + giveTarget: myError1, + wantIs: true, + wantEvery: true, + }, + { + desc: "one matches", + giveErr: newMultiErr(myError1, myError2), + giveTarget: myError1, + wantIs: true, + wantEvery: false, + }, + { + desc: "not multiErrs and non equal", + giveErr: myError1, + giveTarget: myError2, + wantIs: false, + wantEvery: false, + }, + { + desc: "not multiErrs but equal", + giveErr: myError1, + giveTarget: myError1, + wantIs: true, + wantEvery: true, + }, + { + desc: "not multiErr w multiErr target", + giveErr: myError1, + giveTarget: newMultiErr(myError1, myError1), + wantIs: false, + wantEvery: false, + }, + { + desc: "multiErr w multiErr target", + giveErr: newMultiErr(myError1, myError1), + giveTarget: newMultiErr(myError1, myError1), + wantIs: false, + wantEvery: false, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + assert.Equal(t, tt.wantIs, errors.Is(tt.giveErr, tt.giveTarget)) + assert.Equal(t, tt.wantEvery, Every(tt.giveErr, tt.giveTarget)) + }) + } +} + func TestCombine(t *testing.T) { tests := []struct { // Input