diff --git a/app.go b/app.go index 8189f9d71..ef3d35a0d 100644 --- a/app.go +++ b/app.go @@ -704,7 +704,6 @@ func (app *App) start(ctx context.Context) error { if err := app.lifecycle.Start(ctx); err != nil { return err } - app.receivers.Start(ctx) return nil }) } @@ -742,6 +741,7 @@ func (app *App) Stop(ctx context.Context) (err error) { // Alternatively, a signal can be broadcast to all done channels manually by // using the Shutdown functionality (see the [Shutdowner] documentation for details). func (app *App) Done() <-chan os.Signal { + app.receivers.Start() // No-op if running return app.receivers.Done() } @@ -752,6 +752,7 @@ func (app *App) Done() <-chan os.Signal { // in the [ShutdownSignal] struct. // Otherwise, the signal that was received will be set. func (app *App) Wait() <-chan ShutdownSignal { + app.receivers.Start() // No-op if running return app.receivers.Wait() } diff --git a/app_internal_test.go b/app_internal_test.go index efdc2f33a..c9b49f8ab 100644 --- a/app_internal_test.go +++ b/app_internal_test.go @@ -21,8 +21,10 @@ package fx import ( + "context" "errors" "fmt" + "os" "sync" "testing" @@ -115,3 +117,24 @@ func TestAnnotationError(t *testing.T) { assert.ErrorIs(t, err, wantErr) assert.Contains(t, err.Error(), wantErr.Error()) } + +// TestStartDoesNotRegisterSignals verifies that signal.Notify is not called +// when a user starts an app. signal.Notify should only be called when the +// .Wait/.Done are called. Note that app.Run calls .Wait() implicitly. +func TestStartDoesNotRegisterSignals(t *testing.T) { + app := New() + calledNotify := false + + // Mock notify function to spy when this is called. + app.receivers.notify = func(c chan<- os.Signal, sig ...os.Signal) { + calledNotify = true + } + app.receivers.stopNotify = func(c chan<- os.Signal) {} + + app.Start(context.Background()) + defer app.Stop(context.Background()) + assert.False(t, calledNotify, "notify should not be called when app starts") + + _ = app.Wait() // User signals intent have fx listen for signals. This should call notify + assert.True(t, calledNotify, "notify should be called after Wait") +} diff --git a/app_test.go b/app_test.go index 0a46d915c..198000cdc 100644 --- a/app_test.go +++ b/app_test.go @@ -2331,7 +2331,9 @@ func TestHookConstructors(t *testing.T) { func TestDone(t *testing.T) { t.Parallel() - done := fxtest.New(t).Done() + app := fxtest.New(t) + defer app.RequireStop() + done := app.Done() require.NotNil(t, done, "Got a nil channel.") select { case sig := <-done: @@ -2340,6 +2342,38 @@ func TestDone(t *testing.T) { } } +// TestShutdownThenWait tests that if we call .Shutdown before waiting, the wait +// will still return the last shutdown signal. +func TestShutdownThenWait(t *testing.T) { + t.Parallel() + + var ( + s Shutdowner + stopped bool + ) + app := fxtest.New( + t, + Populate(&s), + Invoke(func(lc Lifecycle) { + lc.Append(StopHook(func() { + stopped = true + })) + }), + ).RequireStart() + require.NotNil(t, s) + + err := s.Shutdown(ExitCode(1337)) + assert.NoError(t, err) + assert.False(t, stopped) + + shutdownSig := <-app.Wait() + assert.Equal(t, 1337, shutdownSig.ExitCode) + assert.False(t, stopped) + + app.RequireStop() + assert.True(t, stopped) +} + func TestReplaceLogger(t *testing.T) { t.Parallel() diff --git a/shutdown_test.go b/shutdown_test.go index 49956196a..c4971ffc2 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -116,6 +116,7 @@ func TestShutdown(t *testing.T) { ) require.NoError(t, app.Start(context.Background()), "error starting app") + t.Cleanup(func() { app.Stop(context.Background()) }) // in t.Cleanup so this happens after all subtests return (not just this function) defer require.NoError(t, app.Stop(context.Background())) for i := 0; i < 10; i++ { diff --git a/signal.go b/signal.go index 1b8456899..595a847bc 100644 --- a/signal.go +++ b/signal.go @@ -102,7 +102,7 @@ func (recv *signalReceivers) running() bool { return recv.shutdown != nil && recv.finished != nil } -func (recv *signalReceivers) Start(ctx context.Context) { +func (recv *signalReceivers) Start() { recv.m.Lock() defer recv.m.Unlock() diff --git a/signal_test.go b/signal_test.go index 95d6fe458..18d96f479 100644 --- a/signal_test.go +++ b/signal_test.go @@ -74,9 +74,7 @@ func TestSignal(t *testing.T) { t.Parallel() t.Run("timeout", func(t *testing.T) { recv := newSignalReceivers() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - recv.Start(ctx) + recv.Start() timeoutCtx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() err := recv.Stop(timeoutCtx) @@ -86,8 +84,8 @@ func TestSignal(t *testing.T) { recv := newSignalReceivers() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - recv.Start(ctx) - recv.Start(ctx) // should be a no-op if already running + recv.Start() + recv.Start() // should be a no-op if already running require.NoError(t, recv.Stop(ctx)) }) t.Run("notify", func(t *testing.T) { @@ -106,7 +104,7 @@ func TestSignal(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - recv.Start(ctx) + recv.Start() stub <- syscall.SIGTERM stub <- syscall.SIGTERM require.Equal(t, syscall.SIGTERM, <-recv.Done())