Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cors] add OPTIONS status code + fix function typo #132

Merged
merged 1 commit into from
Sep 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ type cors struct {
maxAge int
ignoreOptions bool
allowCredentials bool
optionStatusCode int
}

// OriginValidator takes an origin string and returns whether or not that origin is allowed.
type OriginValidator func(string) bool

var (
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
defaultCorsOptionStatusCode = 200
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
)

Expand Down Expand Up @@ -130,6 +132,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set(corsAllowOriginHeader, returnOrigin)

if r.Method == corsOptionMethod {
w.WriteHeader(ch.optionStatusCode)
return
}
ch.h.ServeHTTP(w, r)
Expand Down Expand Up @@ -164,9 +167,10 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler {

func parseCORSOptions(opts ...CORSOption) *cors {
ch := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
optionStatusCode: defaultCorsOptionStatusCode,
}

for _, option := range opts {
Expand Down Expand Up @@ -251,7 +255,20 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption {
}
}

// ExposeHeaders can be used to specify headers that are available
// OptionStatusCode sets a custom status code on the OPTIONS requests.
// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
// and can be used if you need a custom status code (i.e 204).
//
// More informations on the spec:
// https://fetch.spec.whatwg.org/#cors-preflight-fetch
func OptionStatusCode(code int) CORSOption {
return func(ch *cors) error {
ch.optionStatusCode = code
return nil
}
}

// ExposedHeaders can be used to specify headers that are available
// and will not be stripped out by the user-agent.
func ExposedHeaders(headers []string) CORSOption {
return func(ch *cors) error {
Expand Down
19 changes: 19 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
statusCode := 204
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
r.Header.Set(corsRequestMethodHeader, "GET")

rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("Options request must not be passed to next handler")
})

CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)

if status := rr.Code; status != statusCode {
t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
Expand Down