Skip to content

Commit

Permalink
Fix unexpected panic when calling Do of a PipelineClient (#997)
Browse files Browse the repository at this point in the history
* fix: Unexpected panic for PipelineClient

PipelineClient would panic when calling `Do` with a nil Response as
the second parm

This commit fixes the unexpected panic by checking nil first before
setting fields for Response

* Add tests to ensure nil resp is valid for PipelineClient
  • Loading branch information
blanet authored Mar 17, 2021
1 parent 0cd7349 commit 860c345
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
16 changes: 4 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1443,8 +1443,7 @@ var (
"Make sure the server returns 'Connection: close' response header before closing the connection")
)

type timeoutError struct {
}
type timeoutError struct{}

func (e *timeoutError) Error() string {
return "timeout"
Expand All @@ -1458,10 +1457,8 @@ func (e *timeoutError) Timeout() bool {
return true
}

var (
// ErrTimeout is returned from timed out calls.
ErrTimeout = &timeoutError{}
)
// ErrTimeout is returned from timed out calls.
var ErrTimeout = &timeoutError{}

// SetMaxConns sets up the maximum number of connections which may be established to all hosts listed in Addr.
func (c *HostClient) SetMaxConns(newMaxConns int) {
Expand Down Expand Up @@ -1571,7 +1568,6 @@ func (c *HostClient) queueForIdle(w *wantConn) {

func (c *HostClient) dialConnFor(w *wantConn) {
conn, err := c.dialHostHard()

if err != nil {
w.tryDeliver(nil, err)
c.decConnsCount()
Expand Down Expand Up @@ -1690,7 +1686,6 @@ func (c *HostClient) decConnsCount() {
if !dialed {
c.connsCount--
}

}

// ConnsCount returns connection count of HostClient
Expand Down Expand Up @@ -2078,15 +2073,13 @@ func (q *wantConnQueue) popFront() *wantConn {

// peekFront returns the wantConn at the front of the queue without removing it.
func (q *wantConnQueue) peekFront() *wantConn {

if q.headPos < len(q.head) {
return q.head[q.headPos]
}
if len(q.tail) > 0 {
return q.tail[0]
}
return nil

}

// cleanFront pops any wantConns that are no longer waiting from the head of the
Expand Down Expand Up @@ -2389,8 +2382,6 @@ func (c *PipelineClient) Do(req *Request, resp *Response) error {
func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
c.init()

resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing

if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
Expand All @@ -2403,6 +2394,7 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
w := acquirePipelineWork(&c.workPool, 0)
w.req = req
if resp != nil {
resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.resp = resp
} else {
w.resp = &w.respCopy
Expand Down
32 changes: 28 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,34 @@ func TestClientNilResp(t *testing.T) {
}
}

func TestPipelineClientNilResp(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
go s.Serve(ln) //nolint:errcheck
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}
if err := c.DoTimeout(req, nil, time.Second); err != nil {
t.Fatal(err)
}
if err := c.DoDeadline(req, nil, time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
}

func TestClientParseConn(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -337,7 +365,6 @@ func TestClientParseConn(t *testing.T) {
if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) {
t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
}

}

func TestClientPostArgs(t *testing.T) {
Expand Down Expand Up @@ -419,7 +446,6 @@ func TestClientRedirectSameSchema(t *testing.T) {
t.Fatalf("HostClient error code response %d", statusCode)
return
}

}

func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
Expand Down Expand Up @@ -2601,7 +2627,6 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}

}()
}
wg.Wait()
Expand Down Expand Up @@ -2694,7 +2719,6 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
}
}

}()
}
wg.Wait()
Expand Down

0 comments on commit 860c345

Please sign in to comment.