diff --git a/cors.go b/cors.go index 39565ff..16efb8a 100644 --- a/cors.go +++ b/cors.go @@ -48,7 +48,10 @@ const ( func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get(corsOriginHeader) if !ch.isOriginAllowed(origin) { - ch.h.ServeHTTP(w, r) + if r.Method != corsOptionMethod || ch.ignoreOptions { + ch.h.ServeHTTP(w, r) + } + return } diff --git a/cors_test.go b/cors_test.go index b042d10..a797be7 100644 --- a/cors_test.go +++ b/cors_test.go @@ -122,6 +122,24 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { } } +func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { + r := newRequest("OPTIONS", "http://www.example.com/") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, "GET") + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Options request must not be passed to next handler") + }) + + CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) + + if status := rr.Code; status != http.StatusOK { + t.Fatalf("bad status: got %v want %v", status, http.StatusOK) + } +} + func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { r := newRequest("OPTIONS", "http://www.example.com/") r.Header.Set("Origin", r.URL.String())