From e2091c59c9f2213621eb5d4bf27ab3ca5f6a2d01 Mon Sep 17 00:00:00 2001 From: Onsi Fakhouri Date: Mon, 10 Oct 2022 14:08:44 -0600 Subject: [PATCH] Eventually/Consistently will forward an attached context to functions that ask for one --- internal/async_assertion.go | 190 +++++++++++++++++++++---------- internal/async_assertion_test.go | 54 +++++++-- 2 files changed, 171 insertions(+), 73 deletions(-) diff --git a/internal/async_assertion.go b/internal/async_assertion.go index defa80512..5674889a8 100644 --- a/internal/async_assertion.go +++ b/internal/async_assertion.go @@ -2,7 +2,6 @@ package internal import ( "context" - "errors" "fmt" "reflect" "runtime" @@ -19,12 +18,21 @@ const ( AsyncAssertionTypeConsistently ) +func (at AsyncAssertionType) String() string { + switch at { + case AsyncAssertionTypeEventually: + return "Eventually" + case AsyncAssertionTypeConsistently: + return "Consistently" + } + return "INVALID ASYNC ASSERTION TYPE" +} + type AsyncAssertion struct { asyncType AsyncAssertionType actualIsFunc bool - actualValue interface{} - actualFunc func() ([]reflect.Value, error) + actual interface{} timeoutInterval time.Duration pollingInterval time.Duration @@ -43,49 +51,9 @@ func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g g: g, } - switch actualType := reflect.TypeOf(actualInput); { - case actualInput == nil || actualType.Kind() != reflect.Func: - out.actualValue = actualInput - case actualType.NumIn() == 0 && actualType.NumOut() > 0: + out.actual = actualInput + if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func { out.actualIsFunc = true - out.actualFunc = func() ([]reflect.Value, error) { - return reflect.ValueOf(actualInput).Call([]reflect.Value{}), nil - } - case actualType.NumIn() == 1 && actualType.In(0).Implements(reflect.TypeOf((*types.Gomega)(nil)).Elem()): - out.actualIsFunc = true - out.actualFunc = func() (values []reflect.Value, err error) { - var assertionFailure error - assertionCapturingGomega := NewGomega(g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) { - skip := 0 - if len(callerSkip) > 0 { - skip = callerSkip[0] - } - _, file, line, _ := runtime.Caller(skip + 1) - assertionFailure = fmt.Errorf("Assertion in callback at %s:%d failed:\n%s", file, line, message) - panic("stop execution") - }) - - defer func() { - if actualType.NumOut() == 0 { - if assertionFailure == nil { - values = []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} - } else { - values = []reflect.Value{reflect.ValueOf(assertionFailure)} - } - } else { - err = assertionFailure - } - if e := recover(); e != nil && assertionFailure == nil { - panic(e) - } - }() - - values = reflect.ValueOf(actualInput).Call([]reflect.Value{reflect.ValueOf(assertionCapturingGomega)}) - return - } - default: - msg := fmt.Sprintf("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes %d arguments and returns %d values.", actualType.NumIn(), actualType.NumOut()) - g.Fail(msg, offset+4) } return out @@ -145,25 +113,115 @@ func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interfa return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n" } -func (assertion *AsyncAssertion) pollActual() (interface{}, error) { - if !assertion.actualIsFunc { - return assertion.actualValue, nil +func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (interface{}, error) { + if len(values) == 0 { + return nil, fmt.Errorf("No values were returned by the function passed to Gomega") } + actual := values[0].Interface() + for i, extraValue := range values[1:] { + extra := extraValue.Interface() + if extra == nil { + continue + } + zero := reflect.Zero(extraValue.Type()).Interface() + if reflect.DeepEqual(extra, zero) { + continue + } + return actual, fmt.Errorf("Unexpected non-nil/non-zero argument at index %d:\n\t<%T>: %#v", i+1, extra, extra) + } + return actual, nil +} - values, err := assertion.actualFunc() - if err != nil { - return nil, err +var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem() +var contextType = reflect.TypeOf(new(context.Context)).Elem() + +func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error { + return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either: + + (a) have return values or + (b) take a Gomega interface as their first argument and use that Gomega instance to make assertions. + +You can learn more at https://onsi.github.io/gomega/#eventually +`, assertion.asyncType, t, assertion.asyncType) +} + +func (assertion *AsyncAssertion) noConfiguredContextForFunctionError(t reflect.Type) error { + return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided to %s. Please pass one in using %s().WithContext(). + +You can learn more at https://onsi.github.io/gomega/#eventually +`, assertion.asyncType, t, assertion.asyncType) +} + +func (assertion *AsyncAssertion) buildActualPoller() (func() (interface{}, error), error) { + if !assertion.actualIsFunc { + return func() (interface{}, error) { return assertion.actual, nil }, nil + } + actualValue := reflect.ValueOf(assertion.actual) + actualType := reflect.TypeOf(assertion.actual) + numIn, numOut := actualType.NumIn(), actualType.NumOut() + + if numIn == 0 && numOut == 0 { + return nil, assertion.invalidFunctionError(actualType) + } else if numIn == 0 { + return func() (interface{}, error) { return assertion.processReturnValues(actualValue.Call([]reflect.Value{})) }, nil + } + takesGomega, takesContext := actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType) + if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) { + takesContext = true + } + if !takesGomega && numOut == 0 { + return nil, assertion.invalidFunctionError(actualType) + } + if takesContext && assertion.ctx == nil { + return nil, assertion.noConfiguredContextForFunctionError(actualType) } - extras := []interface{}{nil} - for _, value := range values[1:] { - extras = append(extras, value.Interface()) + remainingIn := numIn + if takesGomega { + remainingIn -= 1 } - success, message := vetActuals(extras, 0) - if !success { - return nil, errors.New(message) + if takesContext { + remainingIn -= 1 + } + if remainingIn > 0 { + return nil, assertion.invalidFunctionError(actualType) + } + + var assertionFailure error + inValues := []reflect.Value{} + if takesGomega { + inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) { + skip := 0 + if len(callerSkip) > 0 { + skip = callerSkip[0] + } + _, file, line, _ := runtime.Caller(skip + 1) + assertionFailure = fmt.Errorf("Assertion in callback at %s:%d failed:\n%s", file, line, message) + panic("stop execution") + }))) + } + if takesContext { + inValues = append(inValues, reflect.ValueOf(assertion.ctx)) } - return values[0].Interface(), nil + return func() (actual interface{}, err error) { + var values []reflect.Value + assertionFailure = nil + defer func() { + if numOut == 0 { + actual = assertionFailure + } else { + actual, err = assertion.processReturnValues(values) + if assertionFailure != nil { + err = assertionFailure + } + } + if e := recover(); e != nil && assertionFailure == nil { + panic(e) + } + }() + values = actualValue.Call(inValues) + return + }, nil } func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool { @@ -186,14 +244,20 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch var err error mayChange := true - value, err := assertion.pollActual() + assertion.g.THelper() + + pollActual, err := assertion.buildActualPoller() + if err != nil { + assertion.g.Fail(err.Error(), 2+assertion.offset) + return false + } + + value, err := pollActual() if err == nil { mayChange = assertion.matcherMayChange(matcher, value) matches, err = matcher.Match(value) } - assertion.g.THelper() - messageGenerator := func() string { // can be called out of band by Ginkgo if the user requests a progress report lock.Lock() @@ -240,7 +304,7 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch select { case <-time.After(assertion.pollingInterval): - v, e := assertion.pollActual() + v, e := pollActual() lock.Lock() value, err = v, e lock.Unlock() @@ -272,7 +336,7 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch select { case <-time.After(assertion.pollingInterval): - v, e := assertion.pollActual() + v, e := pollActual() lock.Lock() value, err = v, e lock.Unlock() diff --git a/internal/async_assertion_test.go b/internal/async_assertion_test.go index 806fd40ca..bad53dacc 100644 --- a/internal/async_assertion_test.go +++ b/internal/async_assertion_test.go @@ -769,21 +769,55 @@ var _ = Describe("Asynchronous Assertions", func() { }) }) + Context("when passed a function that takes a context", func() { + It("forwards its own configured context", func() { + ctx := context.WithValue(context.Background(), "key", "value") + Eventually(func(ctx context.Context) string { + return ctx.Value("key").(string) + }).WithContext(ctx).Should(Equal("value")) + }) + + It("forwards its own configured context _and_ a Gomega if requested", func() { + ctx := context.WithValue(context.Background(), "key", "value") + Eventually(func(g Gomega, ctx context.Context) { + g.Expect(ctx.Value("key").(string)).To(Equal("schmalue")) + }).WithContext(ctx).Should(MatchError(ContainSubstring("Expected\n : value\nto equal\n : schmalue"))) + }) + + Context("when the assertion does not have an attached context", func() { + It("errors", func() { + ig.G.Eventually(func(ctx context.Context) string { + return ctx.Value("key").(string) + }).Should(Equal("value")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually requested a context.Context, but no context has been provided to func(context.Context) string. Please pass one in using Eventually().WithContext().")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) + }) + }) + }) + Describe("when passed an invalid function", func() { - It("errors immediately", func() { - ig.G.Eventually(func() {}) - Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 0 arguments and returns 0 values.")) - Ω(ig.FailureSkip).Should(Equal([]int{4})) + It("errors with a failure", func() { + ig.G.Eventually(func() {}).Should(Equal("foo")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func()")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) + + ig.G.Consistently(func(ctx context.Context) {}).Should(Equal("foo")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Consistently had an invalid signature of func(context.Context)")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) ig = NewInstrumentedGomega() - ig.G.Eventually(func(g Gomega, foo string) {}) - Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 2 arguments and returns 0 values.")) - Ω(ig.FailureSkip).Should(Equal([]int{4})) + ig.G.Eventually(func(g Gomega, foo string) {}).Should(Equal("foo")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(types.Gomega, string)")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) + + ig.G.Eventually(func(ctx context.Context, g Gomega) {}).Should(Equal("foo")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(context.Context, types.Gomega)")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) ig = NewInstrumentedGomega() - ig.G.Eventually(func(foo string) {}) - Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 1 arguments and returns 0 values.")) - Ω(ig.FailureSkip).Should(Equal([]int{4})) + ig.G.Eventually(func(foo string) {}).Should(Equal("foo")) + Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(string)")) + Ω(ig.FailureSkip).Should(Equal([]int{2})) }) }) })