Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CORS behaviour #173

Merged
merged 5 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ func NewCORSOptions(cors *config.CORS) (*CORSOptions, error) {
}, nil
}

// NeedsVary if a request with not allowed origin is ignored.
func (c *CORSOptions) NeedsVary() bool {
return !c.AllowsOrigin("*")
}

func (c *CORSOptions) AllowsOrigin(origin string) bool {
if c == nil {
return false
Expand Down Expand Up @@ -100,21 +95,29 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {

func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
if !c.isCorsRequest(req) {
headers.Add("Vary", "Origin")
return
}

requestOrigin := req.Header.Get("Origin")
if !c.options.AllowsOrigin(requestOrigin) {
headers.Add("Vary", "Origin")
return
}

// see https://fetch.spec.whatwg.org/#http-responses
if c.options.AllowsOrigin("*") && !c.isCredentialed(req.Header) {
if !c.options.AllowsOrigin("*") {
headers.Set("Access-Control-Allow-Origin", requestOrigin)
headers.Add("Vary", "Origin")
} else if !c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
} else if requestOrigin != "" {
headers.Set("Access-Control-Allow-Origin", requestOrigin)
}

if c.options.AllowCredentials == true {
if c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
headers.Add("Vary", "Origin")
}

if c.isCorsPreflightRequest(req) {
Expand All @@ -131,15 +134,9 @@ func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
if c.options.MaxAge != "" {
headers.Set("Access-Control-Max-Age", c.options.MaxAge)
}
} else if c.options.NeedsVary() {
headers.Add("Vary", "Origin")
}
}

func (c *CORS) isCorsRequest(req *http.Request) bool {
return req.Header.Get("Origin") != ""
}

func (c *CORS) isCredentialed(headers http.Header) bool {
return headers.Get("Cookie") != "" || headers.Get("Authorization") != "" || headers.Get("Proxy-Authorization") != ""
}
197 changes: 85 additions & 112 deletions handler/middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,6 @@ import (
"testing"
)

func TestCORSOptions_NeedsVary(t *testing.T) {
tests := []struct {
name string
corsOptions *CORSOptions
exp bool
}{
{
"any origin",
&CORSOptions{AllowedOrigins: []string{"*"}},
false,
},
{
"one specific origin",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
true,
},
{
"several specific origins",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com", "http://www.another.host.com"}},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(subT *testing.T) {
needed := tt.corsOptions.NeedsVary()
if needed != tt.exp {
subT.Errorf("Expected %t, got: %t", tt.exp, needed)
}
})
}
}

func TestCORSOptions_AllowsOrigin(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -199,50 +167,6 @@ func TestCORSOptions_isCorsPreflightRequest(t *testing.T) {
}
}

func TestCORS_IsCredentialed(t *testing.T) {
type testCase struct {
name string
requestHeaders map[string]string
exp bool
}

tests := []testCase{
{
"Cookie",
map[string]string{"Cookie": "a=b"},
true,
},
{
"Authorization",
map[string]string{"Authorization": "Basic qeinbqtpoib"},
true,
},
{
"Proxy-Authorization",
map[string]string{"Proxy-Authorization": "Basic qeinbqtpoib"},
true,
},
{
"Not credentialed",
map[string]string{},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://1.2.3.4/", nil)
for name, value := range tt.requestHeaders {
req.Header.Set(name, value)
}

credentialed := NewCORSHandler(nil, nil).(*CORS).isCredentialed(req.Header)
if credentialed != tt.exp {
t.Errorf("expected: %t, got: %t", tt.exp, credentialed)
}
})
}
}

func TestCORS_ServeHTTP(t *testing.T) {
upstreamHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Content-Type", "text/plain")
Expand All @@ -259,6 +183,16 @@ func TestCORS_ServeHTTP(t *testing.T) {
requestHeaders map[string]string
expectedResponseHeaders map[string]string
}{
{
"non-CORS",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
map[string]string{},
map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Credentials": "",
"Vary": "Origin",
},
},
{
"specific origin",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
Expand Down Expand Up @@ -308,7 +242,7 @@ func TestCORS_ServeHTTP(t *testing.T) {
},
},
{
"specific origin, cookie credentials",
"specific origin, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -321,20 +255,7 @@ func TestCORS_ServeHTTP(t *testing.T) {
},
},
{
"specific origin, auth credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
"Authorization": "Basic oertnbin",
},
map[string]string{
"Access-Control-Allow-Origin": "https://www.example.com",
"Access-Control-Allow-Credentials": "true",
"Vary": "Origin",
},
},
{
"any origin, cookie credentials",
"any origin, credentials",
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -343,33 +264,32 @@ func TestCORS_ServeHTTP(t *testing.T) {
map[string]string{
"Access-Control-Allow-Origin": "https://www.example.com",
"Access-Control-Allow-Credentials": "true",
"Vary": "",
"Vary": "Origin",
},
},
{
"any origin, auth credentials",
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true},
"origin mismatch",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
map[string]string{
"Origin": "https://www.example.com",
"Authorization": "Basic oertnbin",
"Origin": "https://www.example.org",
},
map[string]string{
"Access-Control-Allow-Origin": "https://www.example.com",
"Access-Control-Allow-Credentials": "true",
"Vary": "",
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Credentials": "",
"Vary": "Origin",
},
},
{
"any origin, proxy auth credentials",
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true},
"origin mismatch, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
"Proxy-Authorization": "Basic oertnbin",
"Origin": "https://www.example.org",
"Cookie": "a=b",
},
map[string]string{
"Access-Control-Allow-Origin": "https://www.example.com",
"Access-Control-Allow-Credentials": "true",
"Vary": "",
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Credentials": "",
"Vary": "Origin",
},
},
}
Expand Down Expand Up @@ -424,7 +344,7 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
expectedResponseHeaders map[string]string
}{
{
"with ACRM",
"specific origin, with ACRM",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -436,10 +356,11 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
"with ACRH",
"specific origin, with ACRH",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -451,10 +372,11 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "X-Foo, X-Bar",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
"with ACRM, ACRH",
"specific origin, with ACRM, ACRH",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -467,10 +389,11 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "X-Foo, X-Bar",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
"with ACRM, credentials",
"specific origin, with ACRM, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -482,10 +405,11 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
"with ACRM, max-age",
"specific origin, with ACRM, max-age",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, MaxAge: "3600"},
map[string]string{
"Origin": "https://www.example.com",
Expand All @@ -497,6 +421,39 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "3600",
"Vary": "Origin",
},
},
{
"any origin, with ACRM",
&CORSOptions{AllowedOrigins: []string{"*"}},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
},
map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "",
},
},
{
"any origin, with ACRM, credentials",
&CORSOptions{AllowedOrigins: []string{"*"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.com",
"Access-Control-Request-Method": "POST",
},
map[string]string{
"Access-Control-Allow-Origin": "https://www.example.com",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
Expand All @@ -512,6 +469,23 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
{
"origin mismatch, credentials",
&CORSOptions{AllowedOrigins: []string{"https://www.example.com"}, AllowCredentials: true},
map[string]string{
"Origin": "https://www.example.org",
"Access-Control-Request-Method": "POST",
},
map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Vary": "Origin",
},
},
}
Expand All @@ -534,7 +508,6 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {

res := rec.Result()

tt.expectedResponseHeaders["Vary"] = ""
tt.expectedResponseHeaders["Content-Type"] = ""

for name, expValue := range tt.expectedResponseHeaders {
Expand Down