From 13a38d26174b16d5b4bf6f1094c1389ec9879572 Mon Sep 17 00:00:00 2001 From: Matias Anaya Date: Wed, 9 May 2018 05:34:52 +1000 Subject: [PATCH] [bugfix] Handle CORS pre-flight request in middleware (#112) --- cors.go | 5 ++++- cors_test.go | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/cors.go b/cors.go index 39565ff..16efb8a 100644 --- a/cors.go +++ b/cors.go @@ -48,7 +48,10 @@ const ( func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get(corsOriginHeader) if !ch.isOriginAllowed(origin) { - ch.h.ServeHTTP(w, r) + if r.Method != corsOptionMethod || ch.ignoreOptions { + ch.h.ServeHTTP(w, r) + } + return } diff --git a/cors_test.go b/cors_test.go index b042d10..a797be7 100644 --- a/cors_test.go +++ b/cors_test.go @@ -122,6 +122,24 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { } } +func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { + r := newRequest("OPTIONS", "http://www.example.com/") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, "GET") + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Options request must not be passed to next handler") + }) + + CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) + + if status := rr.Code; status != http.StatusOK { + t.Fatalf("bad status: got %v want %v", status, http.StatusOK) + } +} + func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { r := newRequest("OPTIONS", "http://www.example.com/") r.Header.Set("Origin", r.URL.String())