diff --git a/pkg/util/proxy/upgradeaware.go b/pkg/util/proxy/upgradeaware.go index f56c17ca3..a3a14241c 100644 --- a/pkg/util/proxy/upgradeaware.go +++ b/pkg/util/proxy/upgradeaware.go @@ -83,6 +83,8 @@ type UpgradeAwareHandler struct { MaxBytesPerSec int64 // Responder is passed errors that occur while setting up proxying. Responder ErrorResponder + // Reject to forward redirect response + RejectForwardingRedirects bool } const defaultFlushInterval = 200 * time.Millisecond @@ -257,6 +259,31 @@ func (h *UpgradeAwareHandler) ServeHTTP(w http.ResponseWriter, req *http.Request proxy.Transport = h.Transport proxy.FlushInterval = h.FlushInterval proxy.ErrorLog = log.New(noSuppressPanicError{}, "", log.LstdFlags) + if h.RejectForwardingRedirects { + oldModifyResponse := proxy.ModifyResponse + proxy.ModifyResponse = func(response *http.Response) error { + code := response.StatusCode + if code >= 300 && code <= 399 { + // close the original response + response.Body.Close() + msg := "the backend attempted to redirect this request, which is not permitted" + // replace the response + *response = http.Response{ + StatusCode: http.StatusBadGateway, + Status: fmt.Sprintf("%d %s", response.StatusCode, http.StatusText(response.StatusCode)), + Body: io.NopCloser(strings.NewReader(msg)), + ContentLength: int64(len(msg)), + } + } else { + if oldModifyResponse != nil { + if err := oldModifyResponse(response); err != nil { + return err + } + } + } + return nil + } + } if h.Responder != nil { // if an optional error interceptor/responder was provided wire it // the custom responder might be used for providing a unified error reporting diff --git a/pkg/util/proxy/upgradeaware_test.go b/pkg/util/proxy/upgradeaware_test.go index f7fcff7c0..6a3a21d8b 100644 --- a/pkg/util/proxy/upgradeaware_test.go +++ b/pkg/util/proxy/upgradeaware_test.go @@ -704,6 +704,83 @@ func TestProxyUpgradeErrorResponse(t *testing.T) { } } +func TestRejectForwardingRedirectsOption(t *testing.T) { + originalBody := []byte(`some data`) + testCases := []struct { + name string + rejectForwardingRedirects bool + serverStatusCode int + expectStatusCode int + expectBody []byte + }{ + { + name: "reject redirection enabled in proxy, backend server sending 200 response", + rejectForwardingRedirects: true, + serverStatusCode: 200, + expectStatusCode: 200, + expectBody: originalBody, + }, + { + name: "reject redirection enabled in proxy, backend server sending 301 response", + rejectForwardingRedirects: true, + serverStatusCode: 301, + expectStatusCode: 502, + expectBody: []byte(`the backend attempted to redirect this request, which is not permitted`), + }, + { + name: "reject redirection disabled in proxy, backend server sending 200 response", + rejectForwardingRedirects: false, + serverStatusCode: 200, + expectStatusCode: 200, + expectBody: originalBody, + }, + { + name: "reject redirection disabled in proxy, backend server sending 301 response", + rejectForwardingRedirects: false, + serverStatusCode: 301, + expectStatusCode: 301, + expectBody: originalBody, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up a backend server + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.serverStatusCode) + w.Write(originalBody) + })) + defer backendServer.Close() + backendServerURL, _ := url.Parse(backendServer.URL) + + // Set up a proxy pointing to the backend + proxyHandler := NewUpgradeAwareHandler(backendServerURL, nil, false, false, &fakeResponder{t: t}) + proxyHandler.RejectForwardingRedirects = tc.rejectForwardingRedirects + proxy := httptest.NewServer(proxyHandler) + defer proxy.Close() + proxyURL, _ := url.Parse(proxy.URL) + + conn, err := net.Dial("tcp", proxyURL.Host) + require.NoError(t, err) + bufferedReader := bufio.NewReader(conn) + + req, _ := http.NewRequest("GET", proxyURL.String(), nil) + require.NoError(t, req.Write(conn)) + // Verify we get the correct response and message body content + resp, err := http.ReadResponse(bufferedReader, nil) + require.NoError(t, err) + assert.Equal(t, tc.expectStatusCode, resp.StatusCode) + data, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tc.expectBody, data) + assert.Equal(t, int64(len(tc.expectBody)), resp.ContentLength) + resp.Body.Close() + + // clean up + conn.Close() + }) + } +} + func TestDefaultProxyTransport(t *testing.T) { tests := []struct { name,