diff --git a/CHANGELOG.md b/CHANGELOG.md index a77bbdda9..190dc4eba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Unreleased changes are available as `coupergateway/couper:edge` container. * Selecting of appropriate [error handler](https://docs.couper.io/configuration/block/error_handler) in two cases ([#753](https://github.com/coupergateway/couper/pull/753)) * Storing of digit-starting string object keys in [request context](https://docs.couper.io/configuration/variables#request) and of digit-starting string header field names in [request](https://docs.couper.io/configuration/variables#request) variable ([#799](https://github.com/coupergateway/couper/pull/799)) * Use of boolean values for the `headers` attribute or [modifiers](https://docs.couper.io/configuration/modifiers) ([#805](https://github.com/coupergateway/couper/pull/805)) + * Duplicate [CORS](https://docs.couper.io/configuration/block/cors) response headers (with backend sending CORS response headers, too) ([#804](https://github.com/coupergateway/couper/pull/804)) * **Dependencies** * build with go 1.21 ([#800](https://github.com/coupergateway/couper/pull/800)) diff --git a/handler/endpoint.go b/handler/endpoint.go index cc84b0579..9939d464d 100644 --- a/handler/endpoint.go +++ b/handler/endpoint.go @@ -163,7 +163,7 @@ func (e *Endpoint) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - w.AddModifier(httpCtx, e.modifier...) + w.AddModifier(e.modifier...) rw = w } diff --git a/handler/file.go b/handler/file.go index c5993b23a..a6a9aaf09 100644 --- a/handler/file.go +++ b/handler/file.go @@ -12,7 +12,6 @@ import ( "github.com/coupergateway/couper/config/runtime/server" "github.com/coupergateway/couper/errors" - "github.com/coupergateway/couper/eval" "github.com/coupergateway/couper/server/writer" "github.com/coupergateway/couper/utils" ) @@ -82,8 +81,7 @@ func (f *File) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if r, ok := rw.(*writer.Response); ok { - evalContext := eval.ContextFromRequest(req) - r.AddModifier(evalContext.HCLContext(), f.modifier...) + r.AddModifier(f.modifier...) } http.ServeContent(rw, req, reqPath, info.ModTime(), file) @@ -97,8 +95,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http. if !strings.HasSuffix(reqPath, "/") { if r, ok := rw.(*writer.Response); ok { - evalContext := eval.ContextFromRequest(req) - r.AddModifier(evalContext.HCLContext(), f.modifier...) + r.AddModifier(f.modifier...) } rw.Header().Set("Location", utils.JoinPath(req.URL.Path, "/")) @@ -116,8 +113,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http. defer file.Close() if r, ok := rw.(*writer.Response); ok { - evalContext := eval.ContextFromRequest(req) - r.AddModifier(evalContext.HCLContext(), f.modifier...) + r.AddModifier(f.modifier...) } http.ServeContent(rw, req, reqPath, info.ModTime(), file) diff --git a/handler/health_test.go b/handler/health_test.go index 2347adec6..c09e5e17c 100644 --- a/handler/health_test.go +++ b/handler/health_test.go @@ -8,9 +8,8 @@ import ( "reflect" "testing" - "github.com/coupergateway/couper/server/writer" - "github.com/coupergateway/couper/handler" + "github.com/coupergateway/couper/server/writer" ) func TestHealth_ServeHTTP(t *testing.T) { diff --git a/handler/middleware/cors.go b/handler/middleware/cors.go index 7a05f5066..b591225f2 100644 --- a/handler/middleware/cors.go +++ b/handler/middleware/cors.go @@ -10,6 +10,7 @@ import ( "github.com/coupergateway/couper/config" "github.com/coupergateway/couper/errors" "github.com/coupergateway/couper/internal/seetie" + "github.com/coupergateway/couper/server/writer" ) var _ http.Handler = &CORS{} @@ -79,7 +80,11 @@ func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler { } func (c *CORS) ServeNextHTTP(rw http.ResponseWriter, nextHandler http.Handler, req *http.Request) { - c.setCorsRespHeaders(rw.Header(), req) + if response, ok := rw.(*writer.Response); ok { + response.AddHeaderModifier(func(header http.Header) { + c.setCorsRespHeaders(header, req) + }) + } if c.isCorsPreflightRequest(req) { rw.WriteHeader(http.StatusNoContent) @@ -100,6 +105,11 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool { } func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) { + headers.Del("Access-Control-Allow-Origin") + headers.Del("Access-Control-Allow-Credentials") + headers.Del("Access-Control-Allow-Headers") + headers.Del("Access-Control-Allow-Methods") + headers.Del("Access-Control-Max-Age") // see https://fetch.spec.whatwg.org/#http-responses allowSpecificOrigin := false if c.options.AllowsOrigin("*") && !c.options.AllowCredentials { diff --git a/handler/middleware/cors_test.go b/handler/middleware/cors_test.go index d83deed54..75904afd7 100644 --- a/handler/middleware/cors_test.go +++ b/handler/middleware/cors_test.go @@ -4,6 +4,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/coupergateway/couper/server/writer" ) func TestCORSOptions_AllowsOrigin(t *testing.T) { @@ -330,7 +332,8 @@ func TestCORS_ServeHTTP(t *testing.T) { } rec := httptest.NewRecorder() - corsHandler.ServeHTTP(rec, req) + r := writer.NewResponseWriter(rec, "") + corsHandler.ServeHTTP(r, req) if !rec.Flushed { rec.Flush() @@ -547,8 +550,8 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) { } rec := httptest.NewRecorder() - - corsHandler.ServeHTTP(rec, req) + r := writer.NewResponseWriter(rec, "") + corsHandler.ServeHTTP(r, req) if !rec.Flushed { rec.Flush() diff --git a/handler/proxy.go b/handler/proxy.go index 12f17c5f6..cace2196a 100644 --- a/handler/proxy.go +++ b/handler/proxy.go @@ -195,10 +195,9 @@ func (p *Proxy) registerWebsocketsResponse(req *http.Request) error { } wsBody := p.getWebsocketsBody() - evalCtx := eval.ContextFromRequest(req) if rw, ok := req.Context().Value(request.ResponseWriter).(*writer.Response); ok { - rw.AddModifier(evalCtx.HCLContextSync(), wsBody, p.context) + rw.AddModifier(wsBody, p.context) } return nil diff --git a/handler/spa.go b/handler/spa.go index da08fe9f1..0acc99c53 100644 --- a/handler/spa.go +++ b/handler/spa.go @@ -18,7 +18,6 @@ import ( "github.com/coupergateway/couper/config" "github.com/coupergateway/couper/config/runtime/server" "github.com/coupergateway/couper/errors" - "github.com/coupergateway/couper/eval" "github.com/coupergateway/couper/server/writer" ) @@ -78,9 +77,8 @@ func (s *Spa) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var content io.ReadSeeker var modTime time.Time - evalContext := eval.ContextFromRequest(req) if r, ok := rw.(*writer.Response); ok { - r.AddModifier(evalContext.HCLContext(), s.modifier...) + r.AddModifier(s.modifier...) } if l := len(s.bootstrapContent); l > 0 { diff --git a/server/http.go b/server/http.go index 82652c627..325b9cb66 100644 --- a/server/http.go +++ b/server/http.go @@ -21,6 +21,7 @@ import ( "github.com/coupergateway/couper/handler" "github.com/coupergateway/couper/handler/middleware" "github.com/coupergateway/couper/logging" + "github.com/coupergateway/couper/server/writer" "github.com/coupergateway/couper/telemetry/instrumentation" "github.com/coupergateway/couper/telemetry/provider" ) @@ -309,7 +310,11 @@ func (s *HTTPServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // due to the middleware callee stack we have to update the 'req' value. *req = *req.WithContext(s.evalCtx.WithClientRequest(req.WithContext(ctx))) - h.ServeHTTP(rw, req) + w := rw + if respW, is := rw.(*writer.Response); is { + w = respW.WithEvalContext(eval.ContextFromRequest(req)) + } + h.ServeHTTP(w, req) } func (s *HTTPServer) setGetBody(h http.Handler, req *http.Request) (opt buffer.Option, err error) { diff --git a/server/http_integration_test.go b/server/http_integration_test.go index ad07ab46a..da0075a18 100644 --- a/server/http_integration_test.go +++ b/server/http_integration_test.go @@ -4611,6 +4611,7 @@ func TestCORS_Configuration(t *testing.T) { acam, acamExists := res.Header["Access-Control-Allow-Methods"] acah, acahExists := res.Header["Access-Control-Allow-Headers"] acac, acacExists := res.Header["Access-Control-Allow-Credentials"] + acax, acaxExists := res.Header["Access-Control-Max-Age"] if tc.expAllowed { if !acaoExists || acao[0] != tc.origin { subT.Errorf("Expected allowed origin, got: %v", acao) @@ -4624,6 +4625,9 @@ func TestCORS_Configuration(t *testing.T) { if !acacExists || acac[0] != "true" { subT.Errorf("Expected allowed credentials, got: %v", acac) } + if !acaxExists || acax[0] != "200" { + subT.Errorf("Expected max-age 200, got: %v", acax) + } } else { if acaoExists { subT.Errorf("Expected not allowed origin, got: %v", acao) @@ -4634,6 +4638,9 @@ func TestCORS_Configuration(t *testing.T) { if acahExists { subT.Errorf("Expected not allowed headers, got: %v", acah) } + if acaxExists { + subT.Errorf("Expected not max-age, got: %v", acax) + } if acacExists { subT.Errorf("Expected not allowed credentials, got: %v", acac) } @@ -4660,6 +4667,9 @@ func TestCORS_Configuration(t *testing.T) { acao, acaoExists = res.Header["Access-Control-Allow-Origin"] acac, acacExists = res.Header["Access-Control-Allow-Credentials"] + acam, acamExists = res.Header["Access-Control-Allow-Methods"] + acah, acahExists = res.Header["Access-Control-Allow-Headers"] + acax, acaxExists = res.Header["Access-Control-Max-Age"] if tc.expAllowed { if !acaoExists || acao[0] != tc.origin { subT.Errorf("Expected allowed origin, got: %v", acao) @@ -4675,6 +4685,15 @@ func TestCORS_Configuration(t *testing.T) { subT.Errorf("Expected not allowed credentials, got: %v", acac) } } + if acamExists { + subT.Errorf("Expected not allowed methods, got: %v", acam) + } + if acahExists { + subT.Errorf("Expected not allowed headers, got: %v", acah) + } + if acaxExists { + subT.Errorf("Expected not max-age, got: %v", acax) + } vary, varyExists = res.Header["Vary"] if !varyExists || strings.Join(vary, ",") != tc.expVary { subT.Errorf("Expected vary %q, got: %q", tc.expVary, strings.Join(vary, ",")) @@ -4698,6 +4717,9 @@ func TestCORS_Configuration(t *testing.T) { acao, acaoExists = res.Header["Access-Control-Allow-Origin"] acac, acacExists = res.Header["Access-Control-Allow-Credentials"] + acam, acamExists = res.Header["Access-Control-Allow-Methods"] + acah, acahExists = res.Header["Access-Control-Allow-Headers"] + acax, acaxExists = res.Header["Access-Control-Max-Age"] if tc.expAllowed { if !acaoExists || acao[0] != tc.origin { subT.Errorf("Expected allowed origin, got: %v", acao) @@ -4713,6 +4735,15 @@ func TestCORS_Configuration(t *testing.T) { subT.Errorf("Expected not allowed credentials, got: %v", acac) } } + if acamExists { + subT.Errorf("Expected not allowed methods, got: %v", acam) + } + if acahExists { + subT.Errorf("Expected not allowed headers, got: %v", acah) + } + if acaxExists { + subT.Errorf("Expected not max-age, got: %v", acax) + } vary, varyExists = res.Header["Vary"] if !varyExists || strings.Join(vary, ",") != tc.expVaryCred { subT.Errorf("Expected vary %q, got: %q", tc.expVaryCred, strings.Join(vary, ",")) diff --git a/server/testdata/integration/config/06_couper.hcl b/server/testdata/integration/config/06_couper.hcl index a04b86018..1264bc215 100644 --- a/server/testdata/integration/config/06_couper.hcl +++ b/server/testdata/integration/config/06_couper.hcl @@ -6,6 +6,7 @@ server "cors" { cors { allowed_origins = "a.com" allow_credentials = true + max_age = "200s" } } @@ -15,6 +16,7 @@ server "cors" { cors { allowed_origins = "b.com" allow_credentials = true + max_age = "200s" } } @@ -23,9 +25,18 @@ server "cors" { cors { allowed_origins = "c.com" allow_credentials = true + max_age = "200s" } endpoint "/" { - response {} + response { + headers = { + access-control-allow-origin = "foo" + access-control-allow-credentials = "bar" + access-control-allow-methods = "BREW" + access-control-allow-headers = "Auth" + access-control-max-age = 300 + } + } } } } diff --git a/server/writer/response.go b/server/writer/response.go index f0653f090..84da3d945 100644 --- a/server/writer/response.go +++ b/server/writer/response.go @@ -22,7 +22,8 @@ type writer interface { } type modifier interface { - AddModifier(*hcl.EvalContext, ...hcl.Body) + AddModifier(...hcl.Body) + AddHeaderModifier(HeaderModifier) } var ( @@ -34,6 +35,8 @@ var ( endOfLine = []byte("\r\n") ) +type HeaderModifier func(header http.Header) + // Response wraps the http.ResponseWriter. type Response struct { hijackedConn net.Conn @@ -45,14 +48,15 @@ type Response struct { statusCode int rawBytesWritten int bytesWritten int - // modifier - evalCtx *hcl.EvalContext - modifier []hcl.Body + // modifiers + evalCtx *eval.Context + modifiers []hcl.Body + headerModifiers []HeaderModifier // security addPrivateCC bool } -// NewResponseWriter creates a new Response object. +// NewResponseWriter creates a new ResponseWriter. It wraps the http.ResponseWriter. func NewResponseWriter(rw http.ResponseWriter, secureCookies string) *Response { return &Response{ rw: rw, @@ -60,6 +64,12 @@ func NewResponseWriter(rw http.ResponseWriter, secureCookies string) *Response { } } +// WithEvalContext sets the eval context for the response modifiers. +func (r *Response) WithEvalContext(ctx *eval.Context) *Response { + r.evalCtx = ctx + return r +} + // Header wraps the Header method of the . func (r *Response) Header() http.Header { return r.rw.Header() @@ -135,9 +145,10 @@ func (r *Response) WriteHeader(statusCode int) { } r.configureHeader() - r.applyModifier() + r.applyHeaderModifiers() + r.applyModifiers() // hcl body modifiers - // !!! Execute after modifier !!! + // execute after modifiers if r.addPrivateCC { r.Header().Add("Cache-Control", "private") } @@ -193,17 +204,29 @@ func (r *Response) AddPrivateCC() { r.addPrivateCC = true } -func (r *Response) AddModifier(evalCtx *hcl.EvalContext, modifier ...hcl.Body) { - r.evalCtx = evalCtx - r.modifier = append(r.modifier, modifier...) +func (r *Response) AddModifier(modifier ...hcl.Body) { + r.modifiers = append(r.modifiers, modifier...) } -func (r *Response) applyModifier() { - if r.evalCtx == nil || r.modifier == nil { +// applyModifiers applies the hcl body modifiers to the response. +func (r *Response) applyModifiers() { + if r.evalCtx == nil || r.modifiers == nil { return } - for _, body := range r.modifier { - _ = eval.ApplyResponseHeaderOps(r.evalCtx, body, r.Header()) + hctx := r.evalCtx.HCLContextSync() + for _, body := range r.modifiers { + _ = eval.ApplyResponseHeaderOps(hctx, body, r.Header()) + } +} + +func (r *Response) AddHeaderModifier(headerModifier HeaderModifier) { + r.headerModifiers = append(r.headerModifiers, headerModifier) +} + +// applyHeaderModifiers applies the http.Header modifiers to the response. +func (r *Response) applyHeaderModifiers() { + for _, modifierFn := range r.headerModifiers { + modifierFn(r.Header()) } }