Skip to content

Commit

Permalink
Eventually/Consistently will forward an attached context to functions…
Browse files Browse the repository at this point in the history
… that ask for one
  • Loading branch information
onsi committed Oct 10, 2022
1 parent 2e34979 commit e2091c5
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 73 deletions.
190 changes: 127 additions & 63 deletions internal/async_assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package internal

import (
"context"
"errors"
"fmt"
"reflect"
"runtime"
Expand All @@ -19,12 +18,21 @@ const (
AsyncAssertionTypeConsistently
)

func (at AsyncAssertionType) String() string {
switch at {
case AsyncAssertionTypeEventually:
return "Eventually"
case AsyncAssertionTypeConsistently:
return "Consistently"
}
return "INVALID ASYNC ASSERTION TYPE"
}

type AsyncAssertion struct {
asyncType AsyncAssertionType

actualIsFunc bool
actualValue interface{}
actualFunc func() ([]reflect.Value, error)
actual interface{}

timeoutInterval time.Duration
pollingInterval time.Duration
Expand All @@ -43,49 +51,9 @@ func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g
g: g,
}

switch actualType := reflect.TypeOf(actualInput); {
case actualInput == nil || actualType.Kind() != reflect.Func:
out.actualValue = actualInput
case actualType.NumIn() == 0 && actualType.NumOut() > 0:
out.actual = actualInput
if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func {
out.actualIsFunc = true
out.actualFunc = func() ([]reflect.Value, error) {
return reflect.ValueOf(actualInput).Call([]reflect.Value{}), nil
}
case actualType.NumIn() == 1 && actualType.In(0).Implements(reflect.TypeOf((*types.Gomega)(nil)).Elem()):
out.actualIsFunc = true
out.actualFunc = func() (values []reflect.Value, err error) {
var assertionFailure error
assertionCapturingGomega := NewGomega(g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
skip := 0
if len(callerSkip) > 0 {
skip = callerSkip[0]
}
_, file, line, _ := runtime.Caller(skip + 1)
assertionFailure = fmt.Errorf("Assertion in callback at %s:%d failed:\n%s", file, line, message)
panic("stop execution")
})

defer func() {
if actualType.NumOut() == 0 {
if assertionFailure == nil {
values = []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())}
} else {
values = []reflect.Value{reflect.ValueOf(assertionFailure)}
}
} else {
err = assertionFailure
}
if e := recover(); e != nil && assertionFailure == nil {
panic(e)
}
}()

values = reflect.ValueOf(actualInput).Call([]reflect.Value{reflect.ValueOf(assertionCapturingGomega)})
return
}
default:
msg := fmt.Sprintf("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes %d arguments and returns %d values.", actualType.NumIn(), actualType.NumOut())
g.Fail(msg, offset+4)
}

return out
Expand Down Expand Up @@ -145,25 +113,115 @@ func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interfa
return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
}

func (assertion *AsyncAssertion) pollActual() (interface{}, error) {
if !assertion.actualIsFunc {
return assertion.actualValue, nil
func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (interface{}, error) {
if len(values) == 0 {
return nil, fmt.Errorf("No values were returned by the function passed to Gomega")
}
actual := values[0].Interface()
for i, extraValue := range values[1:] {
extra := extraValue.Interface()
if extra == nil {
continue
}
zero := reflect.Zero(extraValue.Type()).Interface()
if reflect.DeepEqual(extra, zero) {
continue
}
return actual, fmt.Errorf("Unexpected non-nil/non-zero argument at index %d:\n\t<%T>: %#v", i+1, extra, extra)
}
return actual, nil
}

values, err := assertion.actualFunc()
if err != nil {
return nil, err
var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem()
var contextType = reflect.TypeOf(new(context.Context)).Elem()

func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error {
return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either:
(a) have return values or
(b) take a Gomega interface as their first argument and use that Gomega instance to make assertions.
You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, t, assertion.asyncType)
}

func (assertion *AsyncAssertion) noConfiguredContextForFunctionError(t reflect.Type) error {
return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided to %s. Please pass one in using %s().WithContext().
You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, t, assertion.asyncType)
}

func (assertion *AsyncAssertion) buildActualPoller() (func() (interface{}, error), error) {
if !assertion.actualIsFunc {
return func() (interface{}, error) { return assertion.actual, nil }, nil
}
actualValue := reflect.ValueOf(assertion.actual)
actualType := reflect.TypeOf(assertion.actual)
numIn, numOut := actualType.NumIn(), actualType.NumOut()

if numIn == 0 && numOut == 0 {
return nil, assertion.invalidFunctionError(actualType)
} else if numIn == 0 {
return func() (interface{}, error) { return assertion.processReturnValues(actualValue.Call([]reflect.Value{})) }, nil
}
takesGomega, takesContext := actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType)
if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) {
takesContext = true
}
if !takesGomega && numOut == 0 {
return nil, assertion.invalidFunctionError(actualType)
}
if takesContext && assertion.ctx == nil {
return nil, assertion.noConfiguredContextForFunctionError(actualType)
}
extras := []interface{}{nil}
for _, value := range values[1:] {
extras = append(extras, value.Interface())
remainingIn := numIn
if takesGomega {
remainingIn -= 1
}
success, message := vetActuals(extras, 0)
if !success {
return nil, errors.New(message)
if takesContext {
remainingIn -= 1
}
if remainingIn > 0 {
return nil, assertion.invalidFunctionError(actualType)
}

var assertionFailure error
inValues := []reflect.Value{}
if takesGomega {
inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
skip := 0
if len(callerSkip) > 0 {
skip = callerSkip[0]
}
_, file, line, _ := runtime.Caller(skip + 1)
assertionFailure = fmt.Errorf("Assertion in callback at %s:%d failed:\n%s", file, line, message)
panic("stop execution")
})))
}
if takesContext {
inValues = append(inValues, reflect.ValueOf(assertion.ctx))
}

return values[0].Interface(), nil
return func() (actual interface{}, err error) {
var values []reflect.Value
assertionFailure = nil
defer func() {
if numOut == 0 {
actual = assertionFailure
} else {
actual, err = assertion.processReturnValues(values)
if assertionFailure != nil {
err = assertionFailure
}
}
if e := recover(); e != nil && assertionFailure == nil {
panic(e)
}
}()
values = actualValue.Call(inValues)
return
}, nil
}

func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool {
Expand All @@ -186,14 +244,20 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch
var err error
mayChange := true

value, err := assertion.pollActual()
assertion.g.THelper()

pollActual, err := assertion.buildActualPoller()
if err != nil {
assertion.g.Fail(err.Error(), 2+assertion.offset)
return false
}

value, err := pollActual()
if err == nil {
mayChange = assertion.matcherMayChange(matcher, value)
matches, err = matcher.Match(value)
}

assertion.g.THelper()

messageGenerator := func() string {
// can be called out of band by Ginkgo if the user requests a progress report
lock.Lock()
Expand Down Expand Up @@ -240,7 +304,7 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch

select {
case <-time.After(assertion.pollingInterval):
v, e := assertion.pollActual()
v, e := pollActual()
lock.Lock()
value, err = v, e
lock.Unlock()
Expand Down Expand Up @@ -272,7 +336,7 @@ func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch

select {
case <-time.After(assertion.pollingInterval):
v, e := assertion.pollActual()
v, e := pollActual()
lock.Lock()
value, err = v, e
lock.Unlock()
Expand Down
54 changes: 44 additions & 10 deletions internal/async_assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,21 +769,55 @@ var _ = Describe("Asynchronous Assertions", func() {
})
})

Context("when passed a function that takes a context", func() {
It("forwards its own configured context", func() {
ctx := context.WithValue(context.Background(), "key", "value")
Eventually(func(ctx context.Context) string {
return ctx.Value("key").(string)
}).WithContext(ctx).Should(Equal("value"))
})

It("forwards its own configured context _and_ a Gomega if requested", func() {
ctx := context.WithValue(context.Background(), "key", "value")
Eventually(func(g Gomega, ctx context.Context) {
g.Expect(ctx.Value("key").(string)).To(Equal("schmalue"))
}).WithContext(ctx).Should(MatchError(ContainSubstring("Expected\n <string>: value\nto equal\n <string>: schmalue")))
})

Context("when the assertion does not have an attached context", func() {
It("errors", func() {
ig.G.Eventually(func(ctx context.Context) string {
return ctx.Value("key").(string)
}).Should(Equal("value"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually requested a context.Context, but no context has been provided to func(context.Context) string. Please pass one in using Eventually().WithContext()."))
Ω(ig.FailureSkip).Should(Equal([]int{2}))
})
})
})

Describe("when passed an invalid function", func() {
It("errors immediately", func() {
ig.G.Eventually(func() {})
Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 0 arguments and returns 0 values."))
Ω(ig.FailureSkip).Should(Equal([]int{4}))
It("errors with a failure", func() {
ig.G.Eventually(func() {}).Should(Equal("foo"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func()"))
Ω(ig.FailureSkip).Should(Equal([]int{2}))

ig.G.Consistently(func(ctx context.Context) {}).Should(Equal("foo"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Consistently had an invalid signature of func(context.Context)"))
Ω(ig.FailureSkip).Should(Equal([]int{2}))

ig = NewInstrumentedGomega()
ig.G.Eventually(func(g Gomega, foo string) {})
Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 2 arguments and returns 0 values."))
Ω(ig.FailureSkip).Should(Equal([]int{4}))
ig.G.Eventually(func(g Gomega, foo string) {}).Should(Equal("foo"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(types.Gomega, string)"))
Ω(ig.FailureSkip).Should(Equal([]int{2}))

ig.G.Eventually(func(ctx context.Context, g Gomega) {}).Should(Equal("foo"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(context.Context, types.Gomega)"))
Ω(ig.FailureSkip).Should(Equal([]int{2}))

ig = NewInstrumentedGomega()
ig.G.Eventually(func(foo string) {})
Ω(ig.FailureMessage).Should(Equal("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes 1 arguments and returns 0 values."))
Ω(ig.FailureSkip).Should(Equal([]int{4}))
ig.G.Eventually(func(foo string) {}).Should(Equal("foo"))
Ω(ig.FailureMessage).Should(ContainSubstring("The function passed to Eventually had an invalid signature of func(string)"))
Ω(ig.FailureSkip).Should(Equal([]int{2}))
})
})
})
Expand Down

0 comments on commit e2091c5

Please sign in to comment.