From 629ee45a73aee7fefeee2f93d269cd53435a10aa Mon Sep 17 00:00:00 2001 From: commit-master Date: Tue, 28 Aug 2018 17:31:04 +0200 Subject: [PATCH] [cors] add OPTIONS status code + fix function typo --- cors.go | 30 ++++++++++++++++++++++++------ cors_test.go | 19 +++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/cors.go b/cors.go index 1acf80d..882bcf6 100644 --- a/cors.go +++ b/cors.go @@ -19,14 +19,16 @@ type cors struct { maxAge int ignoreOptions bool allowCredentials bool + optionStatusCode int } // OriginValidator takes an origin string and returns whether or not that origin is allowed. type OriginValidator func(string) bool var ( - defaultCorsMethods = []string{"GET", "HEAD", "POST"} - defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} + defaultCorsOptionStatusCode = 200 + defaultCorsMethods = []string{"GET", "HEAD", "POST"} + defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) ) @@ -130,6 +132,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(corsAllowOriginHeader, returnOrigin) if r.Method == corsOptionMethod { + w.WriteHeader(ch.optionStatusCode) return } ch.h.ServeHTTP(w, r) @@ -164,9 +167,10 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler { func parseCORSOptions(opts ...CORSOption) *cors { ch := &cors{ - allowedMethods: defaultCorsMethods, - allowedHeaders: defaultCorsHeaders, - allowedOrigins: []string{}, + allowedMethods: defaultCorsMethods, + allowedHeaders: defaultCorsHeaders, + allowedOrigins: []string{}, + optionStatusCode: defaultCorsOptionStatusCode, } for _, option := range opts { @@ -251,7 +255,21 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption { } } -// ExposeHeaders can be used to specify headers that are available +// OptionStatusCode sets a custom status code on the OPTIONS requests. +// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory +// and can be used if you need a custom status code (i.e 204) or have a monitoring +// middleware between your server and the CORS middleware. +// +// More informations on the spec: +// https://fetch.spec.whatwg.org/#cors-preflight-fetch +func OptionStatusCode(code int) CORSOption { + return func(ch *cors) error { + ch.optionStatusCode = code + return nil + } +} + +// ExposedHeaders can be used to specify headers that are available // and will not be stripped out by the user-agent. func ExposedHeaders(headers []string) CORSOption { return func(ch *cors) error { diff --git a/cors_test.go b/cors_test.go index a797be7..ad50e51 100644 --- a/cors_test.go +++ b/cors_test.go @@ -122,6 +122,25 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { } } +func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { + statusCode := 204 + r := newRequest("OPTIONS", "http://www.example.com/") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, "GET") + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Options request must not be passed to next handler") + }) + + CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) + + if status := rr.Code; status != statusCode { + t.Fatalf("bad status: got %v want %v", status, http.StatusOK) + } +} + func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { r := newRequest("OPTIONS", "http://www.example.com/") r.Header.Set("Origin", r.URL.String())