From c29a62eb5c232103e5ab47831e8681fab4389df7 Mon Sep 17 00:00:00 2001 From: JacobOaks Date: Tue, 13 Dec 2022 13:53:12 -0500 Subject: [PATCH] Opt-In Panic Recovery (#364) * Initial feature implementation * move tests * Add test for recoverFromPanicsOption.String() * Fix small comments --- constructor.go | 13 ++++++++- container.go | 19 ++++++++++++ container_test.go | 6 ++++ decorate.go | 13 ++++++++- dig_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++ error.go | 54 ++++++++++++++++++++++++++++++++++- invoke.go | 17 ++++++++++- scope.go | 4 +++ 8 files changed, 195 insertions(+), 4 deletions(-) diff --git a/constructor.go b/constructor.go index cf58bec8..7cf0c8ef 100644 --- a/constructor.go +++ b/constructor.go @@ -130,7 +130,7 @@ func (n *constructorNode) String() string { // Call calls this constructor if it hasn't already been called and // injects any values produced by it into the provided container. -func (n *constructorNode) Call(c containerStore) error { +func (n *constructorNode) Call(c containerStore) (err error) { if n.called { return nil } @@ -142,6 +142,17 @@ func (n *constructorNode) Call(c containerStore) error { } } + if n.s.recoverFromPanics { + defer func() { + if p := recover(); p != nil { + err = PanicError{ + fn: n.location, + Panic: p, + } + } + }() + } + args, err := n.paramList.BuildList(c) if err != nil { return errArgumentsFailed{ diff --git a/container.go b/container.go index 57d11cbf..19fc65f3 100644 --- a/container.go +++ b/container.go @@ -176,6 +176,25 @@ func (deferAcyclicVerificationOption) applyOption(c *Container) { c.scope.deferAcyclicVerification = true } +// RecoverFromPanics is an [Option] to recover from panics that occur while +// running functions given to the container. When set, recovered panics +// will be placed into a [PanicError], and returned at the invoke callsite. +// See [PanicError] for an example on how to handle panics with this option +// enabled, and distinguish them from errors. +func RecoverFromPanics() Option { + return recoverFromPanicsOption{} +} + +type recoverFromPanicsOption struct{} + +func (recoverFromPanicsOption) String() string { + return "RecoverFromPanics()" +} + +func (recoverFromPanicsOption) applyOption(c *Container) { + c.scope.recoverFromPanics = true +} + // Changes the source of randomness for the container. // // This will help provide determinism during tests. diff --git a/container_test.go b/container_test.go index 612172f6..e75a41cc 100644 --- a/container_test.go +++ b/container_test.go @@ -61,4 +61,10 @@ func TestOptionStrings(t *testing.T) { assert.Equal(t, "DryRun(true)", fmt.Sprint(DryRun(true))) assert.Equal(t, "DryRun(false)", fmt.Sprint(DryRun(false))) }) + + t.Run("RecoverFromPanics()", func(t *testing.T) { + t.Parallel() + + assert.Equal(t, "RecoverFromPanics()", fmt.Sprint(RecoverFromPanics())) + }) } diff --git a/decorate.go b/decorate.go index c7e79a69..09c84a1e 100644 --- a/decorate.go +++ b/decorate.go @@ -95,7 +95,7 @@ func newDecoratorNode(dcor interface{}, s *Scope) (*decoratorNode, error) { return n, nil } -func (n *decoratorNode) Call(s containerStore) error { +func (n *decoratorNode) Call(s containerStore) (err error) { if n.state == decoratorCalled { return nil } @@ -109,6 +109,17 @@ func (n *decoratorNode) Call(s containerStore) error { } } + if n.s.recoverFromPanics { + defer func() { + if p := recover(); p != nil { + err = PanicError{ + fn: n.location, + Panic: p, + } + } + }() + } + args, err := n.params.BuildList(n.s) if err != nil { return errArgumentsFailed{ diff --git a/dig_test.go b/dig_test.go index d609cd43..162eb823 100644 --- a/dig_test.go +++ b/dig_test.go @@ -1493,6 +1493,79 @@ func TestGroups(t *testing.T) { // --- END OF END TO END TESTS +func TestRecoverFromPanic(t *testing.T) { + + tests := []struct { + name string + setup func(*digtest.Container) + invoke interface{} + wantErr []string + }{ + { + name: "panic in provided function", + setup: func(c *digtest.Container) { + c.RequireProvide(func() int { + panic("terrible sadness") + }) + }, + invoke: func(i int) {}, + wantErr: []string{ + `could not build arguments for function "go.uber.org/dig_test".TestRecoverFromPanic.\S+`, + `failed to build int:`, + `panic: "terrible sadness" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`, + }, + }, + { + name: "panic in decorator", + setup: func(c *digtest.Container) { + c.RequireProvide(func() string { return "" }) + c.RequireDecorate(func(s string) string { + panic("great sadness") + }) + }, + invoke: func(s string) {}, + wantErr: []string{ + `could not build arguments for function "go.uber.org/dig_test".TestRecoverFromPanic.\S+`, + `failed to build string:`, + `panic: "great sadness" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`, + }, + }, + { + name: "panic in invoke", + setup: func(c *digtest.Container) {}, + invoke: func() { panic("terrible woe") }, + wantErr: []string{ + `panic: "terrible woe" in func: "go.uber.org/dig_test".TestRecoverFromPanic.\S+`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + t.Run("without option", func(t *testing.T) { + c := digtest.New(t) + tt.setup(c) + assert.Panics(t, func() { c.Container.Invoke(tt.invoke) }, + "expected panic without dig.RecoverFromPanics() option", + ) + }) + + t.Run("with option", func(t *testing.T) { + c := digtest.New(t, dig.RecoverFromPanics()) + tt.setup(c) + err := c.Container.Invoke(tt.invoke) + require.Error(t, err) + dig.AssertErrorMatches(t, err, tt.wantErr[0], tt.wantErr[1:]...) + var pe dig.PanicError + assert.True(t, errors.As(err, &pe), "expected error chain to contain a PanicError") + _, ok := dig.RootCause(err).(dig.PanicError) + assert.True(t, ok, "expected root cause to be a PanicError") + }) + }) + } +} + func TestProvideConstructorErrors(t *testing.T) { t.Run("multiple-type constructor returns multiple objects of same type", func(t *testing.T) { c := digtest.New(t) diff --git a/error.go b/error.go index 430e4881..24c5685b 100644 --- a/error.go +++ b/error.go @@ -56,6 +56,55 @@ type digError interface { fmt.Formatter } +// A PanicError occurs when a panic occurs while running functions given to the container +// with the [RecoverFromPanic] option being set. It contains the panic message from the +// original panic. A PanicError does not wrap other errors, and it does not implement +// dig.Error, meaning it will be returned from [RootCause]. With the [RecoverFromPanic] +// option set, a panic can be distinguished from dig errors and errors from provided/ +// invoked/decorated functions like so: +// +// rootCause := dig.RootCause(err) +// +// var pe dig.PanicError +// var de dig.Error +// if errors.As(rootCause, &pe) { +// // This is caused by a panic +// } else if errors.As(err, &de) { +// // This is a dig error +// } else { +// // This is an error from one of my provided/invoked functions or decorators +// } +// +// Or, if only interested in distinguishing panics from errors: +// +// var pe dig.PanicError +// if errors.As(err, &pe) { +// // This is caused by a panic +// } else { +// // This is an error +// } +type PanicError struct { + + // The function the panic occurred at + fn *digreflect.Func + + // The panic that was returned from recover() + Panic any +} + +// Format will format the PanicError, expanding the corresponding function if in +v mode. +func (e PanicError) Format(w fmt.State, c rune) { + if w.Flag('+') && c == 'v' { + fmt.Fprintf(w, "panic: %q in func: %+v", e.Panic, e.fn) + } else { + fmt.Fprintf(w, "panic: %q in func: %v", e.Panic, e.fn) + } +} + +func (e PanicError) Error() string { + return fmt.Sprint(e) +} + // formatError will call a dig.Error's writeMessage() method to print the error message // and then will automatically attempt to print errors wrapped underneath (which can create // a recursive effect if the wrapped error's Format() method then points back to this function). @@ -96,8 +145,11 @@ func formatError(e digError, w fmt.State, v rune) { // if errors.As(rootCause, &de) { // // Is a Dig error // } else { -// // Is an error thrown by one of my provided or invoked functions +// // Is an error thrown by one of my provided/invoked/decorated functions // } +// +// See [PanicError] for an example showing how to additionally detect +// and handle panics in provided/invoked/decorated functions. func RootCause(err error) error { var de Error // Dig down to first non dig.Error, or bottom of chain diff --git a/invoke.go b/invoke.go index 69a2f70e..432210cd 100644 --- a/invoke.go +++ b/invoke.go @@ -42,6 +42,10 @@ type InvokeOption interface { // // The function may return an error to indicate failure. The error will be // returned to the caller as-is. +// +// If the [RecoverFromPanics] option was given to the container and a panic +// occurs when invoking, a [PanicError] with the panic contained will be +// returned. See [PanicError] for more info. func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return c.scope.Invoke(function, opts...) } @@ -54,7 +58,7 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { // // The function may return an error to indicate failure. The error will be // returned to the caller as-is. -func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { +func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) (err error) { ftype := reflect.TypeOf(function) if ftype == nil { return newErrInvalidInput("can't invoke an untyped nil", nil) @@ -90,6 +94,17 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { Reason: err, } } + if s.recoverFromPanics { + defer func() { + if p := recover(); p != nil { + err = PanicError{ + fn: digreflect.InspectFunc(function), + Panic: p, + } + } + }() + } + returned := s.invokerFn(reflect.ValueOf(function), args) if len(returned) == 0 { return nil diff --git a/scope.go b/scope.go index 0c6498eb..216cf18a 100644 --- a/scope.go +++ b/scope.go @@ -75,6 +75,9 @@ type Scope struct { // Defer acyclic check on provide until Invoke. deferAcyclicVerification bool + // Recover from panics in user-provided code and wrap in an exported error type. + recoverFromPanics bool + // invokerFn calls a function with arguments provided to Provide or Invoke. invokerFn invokerFn @@ -115,6 +118,7 @@ func (s *Scope) Scope(name string, opts ...ScopeOption) *Scope { child.parentScope = s child.invokerFn = s.invokerFn child.deferAcyclicVerification = s.deferAcyclicVerification + child.recoverFromPanics = s.recoverFromPanics // child copies the parent's graph nodes. child.gh.nodes = append(child.gh.nodes, s.gh.nodes...)