Skip to content

Commit

Permalink
client: Implement DoRedirects (#765)
Browse files Browse the repository at this point in the history
This commit adds a `DoRedirects` method to both `HostClient` and
`Client` as well as top level convenience function of the same name that
is called with the package level `defaultClient`.

Re-implementing this redirect logic in user code is harder than
necessary.
  • Loading branch information
tsenart authored Mar 25, 2020
1 parent 38aa88a commit 75c6008
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 15 deletions.
113 changes: 98 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ func DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return defaultClient.DoDeadline(req, resp, deadline)
}

// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient)
return err
}

// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned, if the dst
Expand Down Expand Up @@ -372,6 +396,30 @@ func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) er
return clientDoDeadline(req, resp, deadline, c)
}

// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}

// Do performs the given http request and fills the given http response.
//
// Request must contain at least non-zero RequestURI with full url (including
Expand Down Expand Up @@ -731,7 +779,7 @@ type clientDoer interface {
func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()

statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)

ReleaseRequest(req)
return statusCode, body, err
Expand Down Expand Up @@ -771,7 +819,7 @@ func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDo
// concurrent requests, since timed out requests on client side
// usually continue execution on the host.
go func() {
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirects(req, dst, url, c)
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c)
ch <- clientURLResponse{
statusCode: statusCodeCopy,
body: bodyCopy,
Expand Down Expand Up @@ -808,29 +856,45 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status
}
}

statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)

ReleaseRequest(req)
return statusCode, body, err
}

var (
errMissingLocation = errors.New("missing Location header for http redirect")
errTooManyRedirects = errors.New("too many redirects detected when doing the request")
// ErrMissingLocation is returned by clients when the Location header is missing on
// an HTTP response with a redirect status code.
ErrMissingLocation = errors.New("missing Location header for http redirect")
// ErrTooManyRedirects is returned by clients when the number of redirects followed
// exceed the max count.
ErrTooManyRedirects = errors.New("too many redirects detected when doing the request")
)

const maxRedirectsCount = 16
const defaultMaxRedirectsCount = 16

func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
resp := AcquireResponse()
bodyBuf := resp.bodyBuffer()
resp.keepBodyBuffer = true
oldBody := bodyBuf.B
bodyBuf.B = dst

statusCode, body, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c)

body = bodyBuf.B
bodyBuf.B = oldBody
resp.keepBodyBuffer = false
ReleaseResponse(resp)

return statusCode, body, err
}

func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) {
scheme := req.uri.Scheme()
req.schemaUpdate = false

redirectsCount := 0

for {
// In case redirect to different scheme
if redirectsCount > 0 && !bytes.Equal(scheme, req.uri.Scheme()) {
Expand Down Expand Up @@ -859,22 +923,17 @@ func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer

redirectsCount++
if redirectsCount > maxRedirectsCount {
err = errTooManyRedirects
err = ErrTooManyRedirects
break
}
location := resp.Header.peek(strLocation)
if len(location) == 0 {
err = errMissingLocation
err = ErrMissingLocation
break
}
url = getRedirectURL(url, location)
}

body = bodyBuf.B
bodyBuf.B = oldBody
resp.keepBodyBuffer = false
ReleaseResponse(resp)

return statusCode, body, err
}

Expand Down Expand Up @@ -994,6 +1053,30 @@ func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time
return clientDoDeadline(req, resp, deadline, c)
}

// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}

func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error {
deadline := time.Now().Add(timeout)
return clientDoDeadline(req, resp, deadline, c)
Expand Down
36 changes: 36 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,42 @@ func TestClientFollowRedirects(t *testing.T) {
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
}
}

for i := 0; i < 10; i++ {
req := AcquireRequest()
resp := AcquireResponse()

req.SetRequestURI("http://xxx/foo")

err := c.DoRedirects(req, resp, 16)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if statusCode := resp.StatusCode(); statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}

if body := string(resp.Body()); body != "/bar" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
}

ReleaseRequest(req)
ReleaseResponse(resp)
}

req := AcquireRequest()
resp := AcquireResponse()

req.SetRequestURI("http://xxx/foo")

err := c.DoRedirects(req, resp, 0)
if have, want := err, ErrTooManyRedirects; have != want {
t.Fatalf("want error: %v, have %v", want, have)
}

ReleaseRequest(req)
ReleaseResponse(resp)
}

func TestClientGetTimeoutSuccess(t *testing.T) {
Expand Down

0 comments on commit 75c6008

Please sign in to comment.