From 8713335f548d2c18effe36c9686f5a88b65eefce Mon Sep 17 00:00:00 2001 From: Kazushi Kitaya Date: Thu, 5 Sep 2019 00:57:51 +0900 Subject: [PATCH] Fix data race in fasthttputil.pipeConn (#645) * add tests for fasthttputil.InmemoryListener * fix data race in pipeConn * update use of readDeadlineChLock --- fasthttputil/inmemory_listener_test.go | 92 ++++++++++++++++++++++++++ fasthttputil/pipeconns.go | 12 +++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/fasthttputil/inmemory_listener_test.go b/fasthttputil/inmemory_listener_test.go index 86aab68e1e..19cec0c80a 100644 --- a/fasthttputil/inmemory_listener_test.go +++ b/fasthttputil/inmemory_listener_test.go @@ -2,7 +2,13 @@ package fasthttputil import ( "bytes" + "context" "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "sync" "testing" "time" ) @@ -90,3 +96,89 @@ func TestInmemoryListener(t *testing.T) { t.Fatalf("timeout") } } + +// echoServerHandler implements http.Handler. +type echoServerHandler struct { + t *testing.T +} + +func (s *echoServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + time.Sleep(time.Millisecond * 100) + if _, err := io.Copy(w, r.Body); err != nil { + s.t.Fatalf("unexpected error: %s", err) + } +} + +func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Client)) { + ln := NewInmemoryListener() + defer ln.Close() + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return ln.Dial() + }, + }, + Timeout: time.Second, + } + + server := &http.Server{ + Handler: &echoServerHandler{t}, + } + + go func() { + if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Fatalf("unexpected error: %s", err) + } + }() + + f(t, client) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + server.Shutdown(ctx) +} + +func testInmemoryListenerHTTPSingle(t *testing.T, client *http.Client, content string) { + res, err := client.Post("http://...", "text/plain", bytes.NewBufferString(content)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + s := string(b) + if string(b) != content { + t.Fatalf("unexpected response %s, expecting %s", s, content) + } +} + +func TestInmemoryListenerHTTPSingle(t *testing.T) { + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + testInmemoryListenerHTTPSingle(t, client, "request") + }) +} + +func TestInmemoryListenerHTTPSerial(t *testing.T) { + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + for i := 0; i < 10; i++ { + testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i)) + } + }) +} + +func TestInmemoryListenerHTTPConcurrent(t *testing.T) { + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i)) + }(i) + } + wg.Wait() + }) +} diff --git a/fasthttputil/pipeconns.go b/fasthttputil/pipeconns.go index aa92b6ff8d..c6ca39dd85 100644 --- a/fasthttputil/pipeconns.go +++ b/fasthttputil/pipeconns.go @@ -87,6 +87,8 @@ type pipeConn struct { readDeadlineCh <-chan time.Time writeDeadlineCh <-chan time.Time + + readDeadlineChLock sync.Mutex } func (c *pipeConn) Write(p []byte) (int, error) { @@ -158,9 +160,12 @@ func (c *pipeConn) readNextByteBuffer(mayBlock bool) error { if !mayBlock { return errWouldBlock } + c.readDeadlineChLock.Lock() + readDeadlineCh := c.readDeadlineCh + c.readDeadlineChLock.Unlock() select { case c.b = <-c.rCh: - case <-c.readDeadlineCh: + case <-readDeadlineCh: c.readDeadlineCh = closedDeadlineCh // rCh may contain data when deadline is reached. // Read the data before returning ErrTimeout. @@ -214,7 +219,10 @@ func (c *pipeConn) SetReadDeadline(deadline time.Time) error { if c.readDeadlineTimer == nil { c.readDeadlineTimer = time.NewTimer(time.Hour) } - c.readDeadlineCh = updateTimer(c.readDeadlineTimer, deadline) + readDeadlineCh := updateTimer(c.readDeadlineTimer, deadline) + c.readDeadlineChLock.Lock() + c.readDeadlineCh = readDeadlineCh + c.readDeadlineChLock.Unlock() return nil }