diff --git a/server.go b/server.go
index 5fa4e12898..5f9c9a215a 100644
--- a/server.go
+++ b/server.go
@@ -167,6 +167,12 @@ type Server struct {
 	// * ErrBrokenChunks
 	ErrorHandler func(ctx *RequestCtx, err error)
 
+	// PanicHandler for reacting to a panic. Called with the result of calling recover() if it is not nil.
+	//
+	// To be on the safe side, implementors are advised to re-panic so the stack would continue to unwind.
+	// This would make sure you're not left in some inconsistent state due to the original panic.
+	PanicHandler func(r interface{})
+
 	// HeaderReceived is called after receiving the header
 	//
 	// non zero RequestConfig field values will overwrite the default configs
@@ -1578,6 +1584,7 @@ func (s *Server) Serve(ln net.Listener) error {
 		WorkerFunc:      s.serveConn,
 		MaxWorkersCount: maxWorkersCount,
 		LogAllErrors:    s.LogAllErrors,
+		PanicHandler:    s.PanicHandler,
 		Logger:          s.logger(),
 		connState:       s.setState,
 	}
@@ -2147,14 +2154,20 @@ func (s *Server) setState(nc net.Conn, state ConnState) {
 
 func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
 	hjc := s.acquireHijackConn(r, c)
-	h(hjc)
-	if br, ok := r.(*bufio.Reader); ok {
-		releaseReader(s, br)
+	if s.PanicHandler != nil {
+		defer func() {
+			s.cleanAfterHijackConn(r, c, hjc)
+			if r := recover(); r != nil {
+				s.PanicHandler(r)
+			}
+		}()
 	}
-	if !s.KeepHijackedConns {
-		c.Close()
-		s.releaseHijackConn(hjc)
+
+	h(hjc)
+
+	if s.PanicHandler == nil {
+		s.cleanAfterHijackConn(r, c, hjc)
 	}
 }
 
@@ -2180,6 +2193,16 @@ func (s *Server) releaseHijackConn(hjc *hijackConn) {
 	s.hijackConnPool.Put(hjc)
 }
 
+func (s *Server) cleanAfterHijackConn(r io.Reader, c net.Conn, hjc *hijackConn) {
+	if br, ok := r.(*bufio.Reader); ok {
+		releaseReader(s, br)
+	}
+	if !s.KeepHijackedConns {
+		c.Close()
+		s.releaseHijackConn(hjc)
+	}
+}
+
 type hijackConn struct {
 	net.Conn
 	r io.Reader
diff --git a/workerpool.go b/workerpool.go
index bfd297c31e..37c922caf6 100644
--- a/workerpool.go
+++ b/workerpool.go
@@ -22,6 +22,8 @@
 type workerPool struct {
 	LogAllErrors bool
 
+	PanicHandler func(r interface{})
+
 	MaxIdleWorkerDuration time.Duration
 
 	Logger Logger
@@ -200,9 +202,27 @@ func (wp *workerPool) release(ch *workerChan) bool {
 	return true
 }
 
+func (wp *workerPool) workerDone() {
+	wp.lock.Lock()
+	wp.workersCount--
+	wp.lock.Unlock()
+}
+
 func (wp *workerPool) workerFunc(ch *workerChan) {
 	var c net.Conn
 
+	if wp.PanicHandler != nil {
+		defer func() {
+			wp.workerDone()
+			if r := recover(); r != nil {
+				if c != nil {
+					c.Close()
+				}
+				wp.PanicHandler(r)
+			}
+		}()
+	}
+
 	var err error
 	for c = range ch.ch {
 		if c == nil {
@@ -231,7 +251,7 @@ func (wp *workerPool) workerFunc(ch *workerChan) {
 		}
 	}
 
-	wp.lock.Lock()
-	wp.workersCount--
-	wp.lock.Unlock()
+	if wp.PanicHandler == nil {
+		wp.workerDone()
+	}
 }
diff --git a/workerpool_test.go b/workerpool_test.go
index 05e1be0bdb..da4bad0e61 100644
--- a/workerpool_test.go
+++ b/workerpool_test.go
@@ -1,8 +1,10 @@
 package fasthttp
 
 import (
+	"fmt"
 	"io/ioutil"
 	"net"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -167,3 +169,111 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
 	}
 	wp.Stop()
 }
+
+func TestWorkerPoolPanicErrorSerial(t *testing.T) {
+	testWorkerPoolPanicErrorMulti(t)
+}
+
+func TestWorkerPoolPanicErrorConcurrent(t *testing.T) {
+	concurrency := 10
+	ch := make(chan struct{}, concurrency)
+	for i := 0; i < concurrency; i++ {
+		go func() {
+			testWorkerPoolPanicErrorMulti(t)
+			ch <- struct{}{}
+		}()
+	}
+	for i := 0; i < concurrency; i++ {
+		select {
+		case <-ch:
+		case <-time.After(time.Second):
+			t.Fatalf("timeout")
+		}
+	}
+}
+
+func testWorkerPoolPanicErrorMulti(t *testing.T) {
+	var globalCount uint64
+	var recoverCount uint64
+	wp := &workerPool{
+		WorkerFunc: func(conn net.Conn) error {
+			count := atomic.AddUint64(&globalCount, 1)
+			switch count % 3 {
+			case 0:
+				panic("foobar")
+			case 1:
+				return fmt.Errorf("fake error")
+			}
+			return nil
+		},
+		MaxWorkersCount:       1000,
+		MaxIdleWorkerDuration: time.Millisecond,
+		Logger:                &customLogger{},
+		PanicHandler: func(r interface{}) {
+			if r == nil {
+				t.Fatalf("PanicHandler got nil")
+			}
+			atomic.AddUint64(&recoverCount, 1)
+		},
+	}
+
+	for i := 0; i < 10; i++ {
+		testWorkerPoolPanicError(t, wp)
+	}
+
+	if recoverCount == 0 {
+		t.Fatalf("PanicHandler was not called")
+	}
+}
+
+func testWorkerPoolPanicError(t *testing.T, wp *workerPool) {
+	wp.Start()
+
+	ln := fasthttputil.NewInmemoryListener()
+
+	clientsCount := 10
+	clientCh := make(chan struct{}, clientsCount)
+	for i := 0; i < clientsCount; i++ {
+		go func() {
+			conn, err := ln.Dial()
+			if err != nil {
+				t.Fatalf("unexpected error: %s", err)
+			}
+			data, err := ioutil.ReadAll(conn)
+			if err != nil {
+				t.Fatalf("unexpected error: %s", err)
+			}
+			if len(data) > 0 {
+				t.Fatalf("unexpected data read: %q. Expecting empty data", data)
+			}
+			if err = conn.Close(); err != nil {
+				t.Fatalf("unexpected error: %s", err)
+			}
+			clientCh <- struct{}{}
+		}()
+	}
+
+	for i := 0; i < clientsCount; i++ {
+		conn, err := ln.Accept()
+		if err != nil {
+			t.Fatalf("unexpected error: %s", err)
+		}
+		if !wp.Serve(conn) {
+			t.Fatalf("worker pool mustn't be full")
+		}
+	}
+
+	for i := 0; i < clientsCount; i++ {
+		select {
+		case <-clientCh:
+		case <-time.After(time.Second):
+			t.Fatalf("timeout")
+		}
+	}
+
+	if err := ln.Close(); err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+
+	wp.Stop()
+}