Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

777 duplicate CORS response headers #804

Merged
merged 8 commits into from
Apr 8, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Unreleased changes are available as `coupergateway/couper:edge` container.
* Selecting of appropriate [error handler](https://docs.couper.io/configuration/block/error_handler) in two cases ([#753](https://github.com/coupergateway/couper/pull/753))
* Storing of digit-starting string object keys in [request context](https://docs.couper.io/configuration/variables#request) and of digit-starting string header field names in [request](https://docs.couper.io/configuration/variables#request) variable ([#799](https://github.com/coupergateway/couper/pull/799))
* Use of boolean values for the `headers` attribute or [modifiers](https://docs.couper.io/configuration/modifiers) ([#805](https://github.com/coupergateway/couper/pull/805))
* Duplicate [CORS](https://docs.couper.io/configuration/block/cors) response headers (with backend sending CORS response headers, too) ([#804](https://github.com/coupergateway/couper/pull/804))

* **Dependencies**
* build with go 1.21 ([#800](https://github.com/coupergateway/couper/pull/800))
Expand Down
2 changes: 1 addition & 1 deletion handler/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (e *Endpoint) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

w.AddModifier(httpCtx, e.modifier...)
w.AddModifier(e.modifier...)
rw = w
}

Expand Down
10 changes: 3 additions & 7 deletions handler/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/coupergateway/couper/config/runtime/server"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/eval"
"github.com/coupergateway/couper/server/writer"
"github.com/coupergateway/couper/utils"
)
Expand Down Expand Up @@ -82,8 +81,7 @@ func (f *File) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

http.ServeContent(rw, req, reqPath, info.ModTime(), file)
Expand All @@ -97,8 +95,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http.

if !strings.HasSuffix(reqPath, "/") {
if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

rw.Header().Set("Location", utils.JoinPath(req.URL.Path, "/"))
Expand All @@ -116,8 +113,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http.
defer file.Close()

if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

http.ServeContent(rw, req, reqPath, info.ModTime(), file)
Expand Down
3 changes: 1 addition & 2 deletions handler/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import (
"reflect"
"testing"

"github.com/coupergateway/couper/server/writer"

"github.com/coupergateway/couper/handler"
"github.com/coupergateway/couper/server/writer"
)

func TestHealth_ServeHTTP(t *testing.T) {
Expand Down
12 changes: 11 additions & 1 deletion handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/coupergateway/couper/config"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/internal/seetie"
"github.com/coupergateway/couper/server/writer"
)

var _ http.Handler = &CORS{}
Expand Down Expand Up @@ -79,7 +80,11 @@ func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler {
}

func (c *CORS) ServeNextHTTP(rw http.ResponseWriter, nextHandler http.Handler, req *http.Request) {
c.setCorsRespHeaders(rw.Header(), req)
if response, ok := rw.(*writer.Response); ok {
response.AddHeaderModifier(func(header http.Header) {
c.setCorsRespHeaders(header, req)
})
}

if c.isCorsPreflightRequest(req) {
rw.WriteHeader(http.StatusNoContent)
Expand All @@ -100,6 +105,11 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {
}

func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
headers.Del("Access-Control-Allow-Origin")
headers.Del("Access-Control-Allow-Credentials")
headers.Del("Access-Control-Allow-Headers")
headers.Del("Access-Control-Allow-Methods")
headers.Del("Access-Control-Max-Age")
// see https://fetch.spec.whatwg.org/#http-responses
allowSpecificOrigin := false
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
Expand Down
9 changes: 6 additions & 3 deletions handler/middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net/http"
"net/http/httptest"
"testing"

"github.com/coupergateway/couper/server/writer"
)

func TestCORSOptions_AllowsOrigin(t *testing.T) {
Expand Down Expand Up @@ -330,7 +332,8 @@ func TestCORS_ServeHTTP(t *testing.T) {
}

rec := httptest.NewRecorder()
corsHandler.ServeHTTP(rec, req)
r := writer.NewResponseWriter(rec, "")
corsHandler.ServeHTTP(r, req)

if !rec.Flushed {
rec.Flush()
Expand Down Expand Up @@ -547,8 +550,8 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
}

rec := httptest.NewRecorder()

corsHandler.ServeHTTP(rec, req)
r := writer.NewResponseWriter(rec, "")
corsHandler.ServeHTTP(r, req)

if !rec.Flushed {
rec.Flush()
Expand Down
3 changes: 1 addition & 2 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ func (p *Proxy) registerWebsocketsResponse(req *http.Request) error {
}

wsBody := p.getWebsocketsBody()
evalCtx := eval.ContextFromRequest(req)

if rw, ok := req.Context().Value(request.ResponseWriter).(*writer.Response); ok {
rw.AddModifier(evalCtx.HCLContextSync(), wsBody, p.context)
rw.AddModifier(wsBody, p.context)
}

return nil
Expand Down
4 changes: 1 addition & 3 deletions handler/spa.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/coupergateway/couper/config"
"github.com/coupergateway/couper/config/runtime/server"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/eval"
"github.com/coupergateway/couper/server/writer"
)

Expand Down Expand Up @@ -78,9 +77,8 @@ func (s *Spa) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var content io.ReadSeeker
var modTime time.Time

evalContext := eval.ContextFromRequest(req)
if r, ok := rw.(*writer.Response); ok {
r.AddModifier(evalContext.HCLContext(), s.modifier...)
r.AddModifier(s.modifier...)
}

if l := len(s.bootstrapContent); l > 0 {
Expand Down
7 changes: 6 additions & 1 deletion server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/coupergateway/couper/handler"
"github.com/coupergateway/couper/handler/middleware"
"github.com/coupergateway/couper/logging"
"github.com/coupergateway/couper/server/writer"
"github.com/coupergateway/couper/telemetry/instrumentation"
"github.com/coupergateway/couper/telemetry/provider"
)
Expand Down Expand Up @@ -309,7 +310,11 @@ func (s *HTTPServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// due to the middleware callee stack we have to update the 'req' value.
*req = *req.WithContext(s.evalCtx.WithClientRequest(req.WithContext(ctx)))

h.ServeHTTP(rw, req)
w := rw
if respW, is := rw.(*writer.Response); is {
w = respW.WithEvalContext(eval.ContextFromRequest(req))
}
h.ServeHTTP(w, req)
}

func (s *HTTPServer) setGetBody(h http.Handler, req *http.Request) (opt buffer.Option, err error) {
Expand Down
31 changes: 31 additions & 0 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4611,6 +4611,7 @@ func TestCORS_Configuration(t *testing.T) {
acam, acamExists := res.Header["Access-Control-Allow-Methods"]
acah, acahExists := res.Header["Access-Control-Allow-Headers"]
acac, acacExists := res.Header["Access-Control-Allow-Credentials"]
acax, acaxExists := res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4624,6 +4625,9 @@ func TestCORS_Configuration(t *testing.T) {
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
if !acaxExists || acax[0] != "200" {
subT.Errorf("Expected max-age 200, got: %v", acax)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
Expand All @@ -4634,6 +4638,9 @@ func TestCORS_Configuration(t *testing.T) {
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
Expand All @@ -4660,6 +4667,9 @@ func TestCORS_Configuration(t *testing.T) {

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
acam, acamExists = res.Header["Access-Control-Allow-Methods"]
acah, acahExists = res.Header["Access-Control-Allow-Headers"]
acax, acaxExists = res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4675,6 +4685,15 @@ func TestCORS_Configuration(t *testing.T) {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVary {
subT.Errorf("Expected vary %q, got: %q", tc.expVary, strings.Join(vary, ","))
Expand All @@ -4698,6 +4717,9 @@ func TestCORS_Configuration(t *testing.T) {

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
acam, acamExists = res.Header["Access-Control-Allow-Methods"]
acah, acahExists = res.Header["Access-Control-Allow-Headers"]
acax, acaxExists = res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4713,6 +4735,15 @@ func TestCORS_Configuration(t *testing.T) {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryCred {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryCred, strings.Join(vary, ","))
Expand Down
13 changes: 12 additions & 1 deletion server/testdata/integration/config/06_couper.hcl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ server "cors" {
cors {
allowed_origins = "a.com"
allow_credentials = true
max_age = "200s"
}
}

Expand All @@ -15,6 +16,7 @@ server "cors" {
cors {
allowed_origins = "b.com"
allow_credentials = true
max_age = "200s"
}
}

Expand All @@ -23,9 +25,18 @@ server "cors" {
cors {
allowed_origins = "c.com"
allow_credentials = true
max_age = "200s"
}
endpoint "/" {
response {}
response {
headers = {
access-control-allow-origin = "foo"
access-control-allow-credentials = "bar"
access-control-allow-methods = "BREW"
access-control-allow-headers = "Auth"
access-control-max-age = 300
}
}
}
}
}
Expand Down
51 changes: 37 additions & 14 deletions server/writer/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type writer interface {
}

type modifier interface {
AddModifier(*hcl.EvalContext, ...hcl.Body)
AddModifier(...hcl.Body)
AddHeaderModifier(HeaderModifier)
}

var (
Expand All @@ -34,6 +35,8 @@ var (
endOfLine = []byte("\r\n")
)

type HeaderModifier func(header http.Header)

// Response wraps the http.ResponseWriter.
type Response struct {
hijackedConn net.Conn
Expand All @@ -45,21 +48,28 @@ type Response struct {
statusCode int
rawBytesWritten int
bytesWritten int
// modifier
evalCtx *hcl.EvalContext
modifier []hcl.Body
// modifiers
evalCtx *eval.Context
modifiers []hcl.Body
headerModifiers []HeaderModifier
// security
addPrivateCC bool
}

// NewResponseWriter creates a new Response object.
// NewResponseWriter creates a new ResponseWriter. It wraps the http.ResponseWriter.
func NewResponseWriter(rw http.ResponseWriter, secureCookies string) *Response {
return &Response{
rw: rw,
secureCookies: secureCookies,
}
}

// WithEvalContext sets the eval context for the response modifiers.
func (r *Response) WithEvalContext(ctx *eval.Context) *Response {
r.evalCtx = ctx
return r
}

// Header wraps the Header method of the <http.ResponseWriter>.
func (r *Response) Header() http.Header {
return r.rw.Header()
Expand Down Expand Up @@ -135,9 +145,10 @@ func (r *Response) WriteHeader(statusCode int) {
}

r.configureHeader()
r.applyModifier()
r.applyHeaderModifiers()
r.applyModifiers() // hcl body modifiers

// !!! Execute after modifier !!!
// execute after modifiers
if r.addPrivateCC {
r.Header().Add("Cache-Control", "private")
}
Expand Down Expand Up @@ -193,17 +204,29 @@ func (r *Response) AddPrivateCC() {
r.addPrivateCC = true
}

func (r *Response) AddModifier(evalCtx *hcl.EvalContext, modifier ...hcl.Body) {
r.evalCtx = evalCtx
r.modifier = append(r.modifier, modifier...)
func (r *Response) AddModifier(modifier ...hcl.Body) {
r.modifiers = append(r.modifiers, modifier...)
}

func (r *Response) applyModifier() {
if r.evalCtx == nil || r.modifier == nil {
// applyModifiers applies the hcl body modifiers to the response.
func (r *Response) applyModifiers() {
if r.evalCtx == nil || r.modifiers == nil {
return
}

for _, body := range r.modifier {
_ = eval.ApplyResponseHeaderOps(r.evalCtx, body, r.Header())
hctx := r.evalCtx.HCLContextSync()
for _, body := range r.modifiers {
_ = eval.ApplyResponseHeaderOps(hctx, body, r.Header())
}
}

func (r *Response) AddHeaderModifier(headerModifier HeaderModifier) {
r.headerModifiers = append(r.headerModifiers, headerModifier)
}

// applyHeaderModifiers applies the http.Header modifiers to the response.
func (r *Response) applyHeaderModifiers() {
for _, modifierFn := range r.headerModifiers {
modifierFn(r.Header())
}
}
Loading