Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORS wildcard origin and allow credentials #2405

Merged
merged 2 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ type (
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool `yaml:"allow_credentials"`

// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
//
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
//
// Optional. Default value is false.
UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"`

// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access.
//
Expand Down Expand Up @@ -203,7 +212,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
allowOrigin = origin
break
}
Expand Down
282 changes: 183 additions & 99 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,106 +11,190 @@ import (
)

func TestCORS(t *testing.T) {
e := echo.New()
var testCases = []struct {
name string
givenMW echo.MiddlewareFunc
whenMethod string
whenHeaders map[string]string
expectHeaders map[string]string
notExpectHeaders map[string]string
}{
{
name: "ok, wildcard origin",
whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
},
{
name: "ok, wildcard AllowedOrigin with no Origin header in request",
notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
},
{
name: "ok, specific AllowOrigins and AllowCredentials",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
}),
whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "localhost",
echo.HeaderAccessControlAllowCredentials: "true",
},
},
{
name: "ok, preflight request with matching origin for `AllowOrigins`",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "localhost",
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
echo.HeaderAccessControlAllowCredentials: "true",
echo.HeaderAccessControlMaxAge: "3600",
},
},
{
name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: true,
MaxAge: 3600,
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*`
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
echo.HeaderAccessControlAllowCredentials: "true",
echo.HeaderAccessControlMaxAge: "3600",
},
},
{
name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: false, // important for this testcase
MaxAge: 3600,
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "*",
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
echo.HeaderAccessControlMaxAge: "3600",
},
notExpectHeaders: map[string]string{
echo.HeaderAccessControlAllowCredentials: "",
},
},
{
name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: true,
UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase
MaxAge: 3600,
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
echo.HeaderAccessControlAllowCredentials: "true",
echo.HeaderAccessControlMaxAge: "3600",
},
},
{
name: "ok, preflight request with Access-Control-Request-Headers",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
echo.HeaderAccessControlRequestHeaders: "Special-Request-Header",
},
expectHeaders: map[string]string{
echo.HeaderAccessControlAllowOrigin: "*",
echo.HeaderAccessControlAllowHeaders: "Special-Request-Header",
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
},
},
{
name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
},
{
name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
}),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()

mw := CORS()
if tc.givenMW != nil {
mw = tc.givenMW
}
h := mw(func(c echo.Context) error {
return nil
})

// Wildcard origin
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))

// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)

// Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))

// Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))

// Preflight request with `AllowOrigins` *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))

// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))

// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))

req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
h(c)
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
method := http.MethodGet
if tc.whenMethod != "" {
method = tc.whenMethod
}
req := httptest.NewRequest(method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
for k, v := range tc.whenHeaders {
req.Header.Set(k, v)
}

err := h(c)

assert.NoError(t, err)
header := rec.Header()
for k, v := range tc.expectHeaders {
assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v)
}
for k, v := range tc.notExpectHeaders {
if v == "" {
assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k)
} else {
assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v)
}
}
})
}
}

func Test_allowOriginScheme(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true
}

extractors, err := CreateExtractors(config.TokenLookup)
if err != nil {
panic(err)
extractors, cErr := CreateExtractors(config.TokenLookup)
if cErr != nil {
panic(cErr)
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand Down
6 changes: 3 additions & 3 deletions middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.ParseTokenFunc = config.defaultParseToken
}

extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
if err != nil {
panic(err)
extractors, cErr := createExtractors(config.TokenLookup, config.AuthScheme)
if cErr != nil {
panic(cErr)
}
if len(config.TokenLookupFuncs) > 0 {
extractors = append(config.TokenLookupFuncs, extractors...)
Expand Down
6 changes: 3 additions & 3 deletions middleware/key_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
panic("echo: key-auth middleware requires a validator function")
}

extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
if err != nil {
panic(err)
extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme)
if cErr != nil {
panic(cErr)
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand Down