From ced2d15cdebc59d10bdce66383bfd8b77794bfb7 Mon Sep 17 00:00:00 2001 From: Lance Ivy Date: Fri, 10 Nov 2017 15:46:38 -0800 Subject: [PATCH] distinguish between explicit and implicit star --- cors.go | 24 ++++++++++++++++-------- cors_test.go | 42 ++++++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/cors.go b/cors.go index 1cf7581..39565ff 100644 --- a/cors.go +++ b/cors.go @@ -111,13 +111,17 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { } returnOrigin := origin - for _, o := range ch.allowedOrigins { - // A configuration of * is different than explicitly setting an allowed - // origin. Returning arbitrary origin headers an an access control allow - // origin header is unsafe and is not required by any use case. - if o == corsOriginMatchAll { - returnOrigin = "*" - break + if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { + returnOrigin = "*" + } else { + for _, o := range ch.allowedOrigins { + // A configuration of * is different than explicitly setting an allowed + // origin. Returning arbitrary origin headers an an access control allow + // origin header is unsafe and is not required by any use case. + if o == corsOriginMatchAll { + returnOrigin = "*" + break + } } } w.Header().Set(corsAllowOriginHeader, returnOrigin) @@ -159,7 +163,7 @@ func parseCORSOptions(opts ...CORSOption) *cors { ch := &cors{ allowedMethods: defaultCorsMethods, allowedHeaders: defaultCorsHeaders, - allowedOrigins: []string{corsOriginMatchAll}, + allowedOrigins: []string{}, } for _, option := range opts { @@ -307,6 +311,10 @@ func (ch *cors) isOriginAllowed(origin string) bool { return ch.allowedOriginValidator(origin) } + if len(ch.allowedOrigins) == 0 { + return true + } + for _, allowedOrigin := range ch.allowedOrigins { if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { return true diff --git a/cors_test.go b/cors_test.go index 61eb18f..b042d10 100644 --- a/cors_test.go +++ b/cors_test.go @@ -313,7 +313,7 @@ func TestCORSWithMultipleHandlers(t *testing.T) { } } -func TestCORSHandlerWithCustomValidator(t *testing.T) { +func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { r := newRequest("GET", "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() @@ -327,32 +327,20 @@ func TestCORSHandlerWithCustomValidator(t *testing.T) { return false } - // Specially craft a CORS object. - handleFunc := func(h http.Handler) http.Handler { - c := &cors{ - allowedMethods: defaultCorsMethods, - allowedHeaders: defaultCorsHeaders, - allowedOrigins: []string{"http://a.example.com"}, - h: h, - } - AllowedOriginValidator(originValidator)(c) - return c - } - - handleFunc(testHandler).ServeHTTP(rr, r) + CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) header := rr.HeaderMap.Get(corsAllowOriginHeader) if header != r.URL.String() { t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) } - } -func TestCORSAllowStar(t *testing.T) { +func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { r := newRequest("GET", "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + originValidator := func(origin string) bool { if strings.HasSuffix(origin, ".example.com") { return true @@ -360,12 +348,26 @@ func TestCORSAllowStar(t *testing.T) { return false } - CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) + CORS( + AllowedOriginValidator(originValidator), + AllowedOrigins([]string{"*"}), + )(testHandler).ServeHTTP(rr, r) header := rr.HeaderMap.Get(corsAllowOriginHeader) - // Because * is the default CORS policy (which is safe), we should be - // expect a * returned here as the Access Control Allow Origin header if header != "*" { - t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) + t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header) } +} +func TestCORSAllowStar(t *testing.T) { + r := newRequest("GET", "http://a.example.com") + r.Header.Set("Origin", r.URL.String()) + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + CORS()(testHandler).ServeHTTP(rr, r) + header := rr.HeaderMap.Get(corsAllowOriginHeader) + if header != "*" { + t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header) + } }