diff --git a/middleware/proxy.go b/middleware/proxy.go index 74f49de8a..e4f98d9ed 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -29,6 +29,33 @@ type ( // Required. Balancer ProxyBalancer + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Examples: @@ -71,7 +98,8 @@ type ( Next(echo.Context) *ProxyTarget } - // TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target. + // TargetProvider defines an interface that gives the opportunity for balancer + // to return custom errors when selecting target. TargetProvider interface { NextTarget(echo.Context) (*ProxyTarget, error) } @@ -107,14 +135,14 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { - c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() @@ -122,7 +150,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { // Write header err = r.Write(out) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } @@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { go cp(in, out) err = <-errCh if err != nil && err != io.EOF { - c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL)) } }) } @@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { return b.targets[b.random.Intn(len(b.targets))] } -// Next returns an upstream target using round-robin technique. +// Next returns an upstream target using round-robin technique. In the case +// where a previously failed request is being retried, the round-robin +// balancer will attempt to use the next target relative to the original +// request. If the list of targets held by the balancer is modified while a +// failed request is being retried, it is possible that the balancer will +// return the original failed target. // // Note: `nil` is returned in case upstream target list is empty. func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { @@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { } else if len(b.targets) == 1 { return b.targets[0] } - // reset the index if out of bounds - if b.i >= len(b.targets) { - b.i = 0 + + var i int + const lastIdxKey = "_round_robin_last_index" + // This request is a retry, start from the index of the previous + // target to ensure we don't attempt to retry the request with + // the same failed target + if c.Get(lastIdxKey) != nil { + i = c.Get(lastIdxKey).(int) + i++ + if i >= len(b.targets) { + i = 0 + } + } else { + // This is a first time request, use the global index + if b.i >= len(b.targets) { + b.i = 0 + } + i = b.i + b.i++ } - t := b.targets[b.i] - b.i++ - return t + + c.Set(lastIdxKey, i) + return b.targets[i] } // Proxy returns a Proxy middleware. @@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { // ProxyWithConfig returns a Proxy middleware with config. // See: `Proxy()` func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { + if config.Balancer == nil { + panic("echo: proxy middleware requires balancer") + } // Defaults if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } - if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + if config.RetryFilter == nil { + config.RetryFilter = func(c echo.Context, e error) bool { + if httpErr, ok := e.(*echo.HTTPError); ok { + return httpErr.Code == http.StatusBadGateway + } + return false + } + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c echo.Context, err error) error { + return err + } } - if config.Rewrite != nil { if config.RegexRewrite == nil { config.RegexRewrite = make(map[*regexp.Regexp]string) @@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } provider, isTargetProvider := config.Balancer.(TargetProvider) + return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() - - var tgt *ProxyTarget - if isTargetProvider { - tgt, err = provider.NextTarget(c) - if err != nil { - return err - } - } else { - tgt = config.Balancer.Next(c) - } - c.Set(config.ContextKey, tgt) - if err := rewriteURL(config.RegexRewrite, req); err != nil { - return err + return config.ErrorHandler(c, err) } // Fix header @@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } - // Proxy - switch { - case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) - } - if e, ok := c.Get("_error").(error); ok { - err = e - } + retries := config.RetryCount + for { + var tgt *ProxyTarget + var err error + if isTargetProvider { + tgt, err = provider.NextTarget(c) + if err != nil { + return config.ErrorHandler(c, err) + } + } else { + tgt = config.Balancer.Next(c) + } - return + c.Set(config.ContextKey, tgt) + + //If retrying a failed request, clear any previous errors from + //context here so that balancers have the option to check for + //errors that occurred using previous target + if retries < config.RetryCount { + c.Set("_error", nil) + } + + // Proxy + switch { + case c.IsWebSocket(): + proxyRaw(tgt, c).ServeHTTP(res, req) + case req.Header.Get(echo.HeaderAccept) == "text/event-stream": + default: + proxyHTTP(tgt, c, config).ServeHTTP(res, req) + } + + err, hasError := c.Get("_error").(error) + if !hasError { + return nil + } + + retry := retries > 0 && config.RetryFilter(c, err) + if !retry { + return config.ErrorHandler(c, err) + } + + retries-- + } } } } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 122dddeba..1b5ba6cbe 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -393,6 +394,321 @@ func TestProxyError(t *testing.T) { assert.Equal(t, http.StatusBadGateway, rec.Code) } +func TestProxyRetries(t *testing.T) { + + newServer := func(res int) (*url.URL, *httptest.Server) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(res) + }), + ) + targetURL, _ := url.Parse(server.URL) + return targetURL, server + } + + targetURL, server := newServer(http.StatusOK) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: targetURL, + } + + targetURL, server = newServer(http.StatusBadRequest) + defer server.Close() + goodTargetWith40X := &ProxyTarget{ + Name: "Good with 40X", + URL: targetURL, + } + + targetURL, _ = url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: targetURL, + } + + alwaysRetryFilter := func(c echo.Context, e error) bool { return true } + neverRetryFilter := func(c echo.Context, e error) bool { return false } + + testCases := []struct { + name string + retryCount int + retryFilters []func(c echo.Context, e error) bool + targets []*ProxyTarget + expectedResponse int + }{ + { + name: "retry count 0 does not attempt retry on fail", + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 1 does not attempt retry on success", + retryCount: 1, + targets: []*ProxyTarget{ + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does retry on handler return true", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does not retry on handler return false", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when no more retries left", + retryCount: 2, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as only 2 retries + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when retries left but handler returns false", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as retry handler returns false on 2nd check + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 3 succeeds", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "40x responses are not retried", + retryCount: 1, + targets: []*ProxyTarget{ + goodTargetWith40X, + goodTarget, + }, + expectedResponse: http.StatusBadRequest, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + retryFilterCall := 0 + retryFilter := func(c echo.Context, e error) bool { + if len(tc.retryFilters) == 0 { + assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) + } + + retryFilterCall++ + + nextRetryFilter := tc.retryFilters[0] + tc.retryFilters = tc.retryFilters[1:] + + return nextRetryFilter(c, e) + } + + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer(tc.targets), + RetryCount: tc.retryCount, + RetryFilter: retryFilter, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedResponse, rec.Code) + if len(tc.retryFilters) > 0 { + assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters))) + } + }) + } +} + +func TestProxyRetryWithBackendTimeout(t *testing.T) { + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.ResponseHeaderTimeout = time.Millisecond * 500 + + timeoutBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(404) + }), + ) + defer timeoutBackend.Close() + + timeoutTargetURL, _ := url.Parse(timeoutBackend.URL) + goodBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }), + ) + defer goodBackend.Close() + + goodTargetURL, _ := url.Parse(goodBackend.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Transport: transport, + Balancer: NewRoundRobinBalancer([]*ProxyTarget{ + { + Name: "Timeout", + URL: timeoutTargetURL, + }, + { + Name: "Good", + URL: goodTargetURL, + }, + }), + RetryCount: 1, + }, + )) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, 200, rec.Code) + }() + } + + wg.Wait() + +} + +func TestProxyErrorHandler(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + goodURL, _ := url.Parse(server.URL) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: goodURL, + } + + badURL, _ := url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: badURL, + } + + transformedError := errors.New("a new error") + + testCases := []struct { + name string + target *ProxyTarget + errorHandler func(c echo.Context, e error) error + expectFinalError func(t *testing.T, err error) + }{ + { + name: "Error handler not invoked when request success", + target: goodTarget, + errorHandler: func(c echo.Context, e error) error { + assert.FailNow(t, "error handler should not be invoked") + return e + }, + }, + { + name: "Error handler invoked when request fails", + target: badTarget, + errorHandler: func(c echo.Context, e error) error { + httpErr, ok := e.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return transformedError + }, + expectFinalError: func(t *testing.T, err error) { + assert.Equal(t, transformedError, err, "transformed error not returned from proxy") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}), + ErrorHandler: tc.errorHandler, + }, + )) + + errorHandlerCalled := false + e.HTTPErrorHandler = func(err error, c echo.Context) { + errorHandlerCalled = true + tc.expectFinalError(t, err) + e.DefaultHTTPErrorHandler(err, c) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + if !errorHandlerCalled && tc.expectFinalError != nil { + t.Fatalf("error handler was not called") + } + + }) + } +} + func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { var timeoutStop sync.WaitGroup timeoutStop.Add(1)