Skip to content

Commit

Permalink
Support calling Serve multiple times on a Server (#731)
Browse files Browse the repository at this point in the history
You can use the following methods in the handler to find out which
listener the connection is coming in on.
RequestCtx.IsTLS()
RequestCtx.LocalAddr()
RequestCtx.Request.Header.Host()
  • Loading branch information
erikdubbelboer authored Jan 23, 2020
1 parent 03813ae commit b0102c9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
26 changes: 15 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ type Server struct {
writerPool sync.Pool
hijackConnPool sync.Pool

// We need to know our listener so we can close it in Shutdown().
ln net.Listener
// We need to know our listeners so we can close them in Shutdown().
ln []net.Listener

mu sync.Mutex
open int32
Expand Down Expand Up @@ -1577,20 +1577,21 @@ func (s *Server) Serve(ln net.Listener) error {
var c net.Conn
var err error

maxWorkersCount := s.getConcurrency()

s.mu.Lock()
{
if s.ln != nil {
s.mu.Unlock()
return ErrAlreadyServing
s.ln = append(s.ln, ln)
if s.done == nil {
s.done = make(chan struct{})
}

s.ln = ln
s.done = make(chan struct{})
if s.concurrencyCh == nil {
s.concurrencyCh = make(chan struct{}, maxWorkersCount)
}
}
s.mu.Unlock()

maxWorkersCount := s.getConcurrency()
s.concurrencyCh = make(chan struct{}, maxWorkersCount)
wp := &workerPool{
WorkerFunc: s.serveConn,
MaxWorkersCount: maxWorkersCount,
Expand Down Expand Up @@ -1663,8 +1664,10 @@ func (s *Server) Shutdown() error {
return nil
}

if err := s.ln.Close(); err != nil {
return err
for _, ln := range s.ln {
if err := ln.Close(); err != nil {
return err
}
}

if s.done != nil {
Expand All @@ -1684,6 +1687,7 @@ func (s *Server) Shutdown() error {
time.Sleep(time.Millisecond * 100)
}

s.done = nil
s.ln = nil
return nil
}
Expand Down
44 changes: 44 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3047,6 +3047,50 @@ func TestShutdownErr(t *testing.T) {
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}

func TestMultipleServe(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
}

ln1 := fasthttputil.NewInmemoryListener()
ln2 := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln1); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()
go func() {
if err := s.Serve(ln2); err != nil {
t.Errorf("unexepcted error: %s", err)
}
}()

conn, err := ln1.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")

conn, err = ln2.Dial()
if err != nil {
t.Fatalf("unexepcted error: %s", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %s", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}

func TestMaxBodySizePerRequest(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit b0102c9

Please sign in to comment.