Skip to content

Commit

Permalink
[cors] add OPTIONS status code + fix function typo
Browse files Browse the repository at this point in the history
  • Loading branch information
commit-master committed Sep 11, 2018
1 parent 7e0847f commit 629ee45
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
30 changes: 24 additions & 6 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ type cors struct {
maxAge int
ignoreOptions bool
allowCredentials bool
optionStatusCode int
}

// OriginValidator takes an origin string and returns whether or not that origin is allowed.
type OriginValidator func(string) bool

var (
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
defaultCorsOptionStatusCode = 200
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
)

Expand Down Expand Up @@ -130,6 +132,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set(corsAllowOriginHeader, returnOrigin)

if r.Method == corsOptionMethod {
w.WriteHeader(ch.optionStatusCode)
return
}
ch.h.ServeHTTP(w, r)
Expand Down Expand Up @@ -164,9 +167,10 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler {

func parseCORSOptions(opts ...CORSOption) *cors {
ch := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
optionStatusCode: defaultCorsOptionStatusCode,
}

for _, option := range opts {
Expand Down Expand Up @@ -251,7 +255,21 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption {
}
}

// ExposeHeaders can be used to specify headers that are available
// OptionStatusCode sets a custom status code on the OPTIONS requests.
// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
// and can be used if you need a custom status code (i.e 204) or have a monitoring
// middleware between your server and the CORS middleware.
//
// More informations on the spec:
// https://fetch.spec.whatwg.org/#cors-preflight-fetch
func OptionStatusCode(code int) CORSOption {
return func(ch *cors) error {
ch.optionStatusCode = code
return nil
}
}

// ExposedHeaders can be used to specify headers that are available
// and will not be stripped out by the user-agent.
func ExposedHeaders(headers []string) CORSOption {
return func(ch *cors) error {
Expand Down
19 changes: 19 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
statusCode := 204
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
r.Header.Set(corsRequestMethodHeader, "GET")

rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("Options request must not be passed to next handler")
})

CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)

if status := rr.Code; status != statusCode {
t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
Expand Down

0 comments on commit 629ee45

Please sign in to comment.