diff --git a/server.go b/server.go index bcf30cb9e3..4e44122454 100644 --- a/server.go +++ b/server.go @@ -1887,6 +1887,8 @@ func (s *Server) Shutdown() error { // // ShutdownWithContext does not close keepalive connections so it's recommended to set ReadTimeout and IdleTimeout // to something else than 0. +// +// When ShutdownWithContext returns errors, any operation to the Server is unavailable. func (s *Server) ShutdownWithContext(ctx context.Context) (err error) { s.mu.Lock() defer s.mu.Unlock() @@ -1898,11 +1900,7 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) { return nil } - for _, ln := range s.ln { - if err = ln.Close(); err != nil { - return err - } - } + lnerr := s.closeListenersLocked() if s.done != nil { close(s.done) @@ -1913,28 +1911,25 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) { // Now we just have to wait until all workers are done or timeout. ticker := time.NewTicker(time.Millisecond * 100) defer ticker.Stop() -END: + for { s.closeIdleConns() if open := atomic.LoadInt32(&s.open); open == 0 { - break + // There may be a pending request to call ctx.Done(). Therefore, we only set it to nil when open == 0. + s.done = nil + return lnerr } // This is not an optimal solution but using a sync.WaitGroup // here causes data races as it's hard to prevent Add() to be called // while Wait() is waiting. select { case <-ctx.Done(): - err = ctx.Err() - break END + return ctx.Err() case <-ticker.C: continue } } - - s.done = nil - s.ln = nil - return err } func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { @@ -2749,15 +2744,7 @@ func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) { // Note: Because creating a new channel for every request is just too expensive, so // RequestCtx.s.done is only closed when the server is shutting down. func (ctx *RequestCtx) Done() <-chan struct{} { - // fix use new variables to prevent panic caused by modifying the original done chan to nil. - done := ctx.s.done - - if done == nil { - done = make(chan struct{}, 1) - done <- struct{}{} - return done - } - return done + return ctx.s.done } // Err returns a non-nil error value after Done is closed, @@ -2934,6 +2921,17 @@ func (s *Server) closeIdleConns() { s.idleConnsMu.Unlock() } +func (s *Server) closeListenersLocked() error { + var err error + for _, ln := range s.ln { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + } + s.ln = nil + return err +} + // A ConnState represents the state of a client connection to a server. // It's used by the optional Server.ConnState hook. type ConnState int diff --git a/server_race_test.go b/server_race_test.go new file mode 100644 index 0000000000..a409cd804e --- /dev/null +++ b/server_race_test.go @@ -0,0 +1,46 @@ +//go:build race + +package fasthttp + +import ( + "context" + "github.com/valyala/fasthttp/fasthttputil" + "math" + "testing" +) + +func TestServerDoneRace(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + for i := 0; i < math.MaxInt; i++ { + ctx.Done() + } + }, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: go.dev\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it + "GET / HTTP/1.1\r\nHost: go.dev\r\n\r\n")); err != nil { + t.Fatal(err) + } + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + + s.ShutdownWithContext(ctx) +}