Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor CORS handler #28587

Merged
merged 6 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions custom/conf/app.example.ini
Original file line number Diff line number Diff line change
Expand Up @@ -1158,15 +1158,9 @@ LEVEL = Info
;; enable cors headers (disabled by default)
;ENABLED = false
;;
;; scheme of allowed requests
;SCHEME = http
;;
;; list of requesting domains that are allowed
;; list of requesting origins that are allowed, eg: "https://*.example.com"
;ALLOW_DOMAIN = *
;;
;; allow subdomains of headers listed above to request
;ALLOW_SUBDOMAIN = false
;;
;; list of methods allowed to request
;METHODS = GET,HEAD,POST,PUT,PATCH,DELETE,OPTIONS
;;
Expand Down
4 changes: 1 addition & 3 deletions docs/content/administration/config-cheat-sheet.en-us.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ The following configuration set `Content-Type: application/vnd.android.package-a
## CORS (`cors`)

- `ENABLED`: **false**: enable cors headers (disabled by default)
- `SCHEME`: **http**: scheme of allowed requests
- `ALLOW_DOMAIN`: **\***: list of requesting domains that are allowed
- `ALLOW_SUBDOMAIN`: **false**: allow subdomains of headers listed above to request
- `ALLOW_DOMAIN`: **\***: list of requesting origins that are allowed, eg: "https://*.example.com"
- `METHODS`: **GET,HEAD,POST,PUT,PATCH,DELETE,OPTIONS**: list of methods allowed to request
- `MAX_AGE`: **10m**: max time to cache response
- `ALLOW_CREDENTIALS`: **false**: allow request with credentials
Expand Down
2 changes: 0 additions & 2 deletions docs/content/administration/config-cheat-sheet.zh-cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ menu:
## 跨域 (`cors`)

- `ENABLED`: **false**: 启用 CORS 头部(默认禁用)
- `SCHEME`: **http**: 允许请求的协议
- `ALLOW_DOMAIN`: **\***: 允许请求的域名列表
- `ALLOW_SUBDOMAIN`: **false**: 允许上述列出的头部的子域名发出请求。
- `METHODS`: **GET,HEAD,POST,PUT,PATCH,DELETE,OPTIONS**: 允许发起的请求方式列表
- `MAX_AGE`: **10m**: 缓存响应的最大时间
- `ALLOW_CREDENTIALS`: **false**: 允许带有凭据的请求
Expand Down
2 changes: 1 addition & 1 deletion modules/public/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func FileHandlerFunc() http.HandlerFunc {
assetFS := AssetFS()
return func(resp http.ResponseWriter, req *http.Request) {
if req.Method != "GET" && req.Method != "HEAD" {
resp.WriteHeader(http.StatusNotFound)
resp.WriteHeader(http.StatusMethodNotAllowed)
return
}
handleRequest(resp, req, assetFS, req.URL.Path)
Expand Down
4 changes: 1 addition & 3 deletions modules/setting/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ import (
// CORSConfig defines CORS settings
var CORSConfig = struct {
Enabled bool
Scheme string
AllowDomain []string
AllowSubdomain bool
AllowDomain []string // FIXME: this option is from legacy code, it actually works as "AllowedOrigins". When refactoring in the future, the config option should also be renamed together.
Methods []string
MaxAge time.Duration
AllowCredentials bool
Expand Down
24 changes: 6 additions & 18 deletions modules/web/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,18 @@ func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Han
return middlewares, handlerFunc
}

func (r *Route) Methods(method, pattern string, h ...any) {
// Methods adds the same handlers for multiple http "methods" (separated by ",").
// If any method is invalid, the lower level router will panic.
func (r *Route) Methods(methods, pattern string, h ...any) {
middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h)
fullPattern := r.getPattern(pattern)
if strings.Contains(method, ",") {
methods := strings.Split(method, ",")
if strings.Contains(methods, ",") {
methods := strings.Split(methods, ",")
for _, method := range methods {
r.R.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc)
}
} else {
r.R.With(middlewares...).Method(method, fullPattern, handlerFunc)
r.R.With(middlewares...).Method(methods, fullPattern, handlerFunc)
}
}

Expand All @@ -136,20 +138,6 @@ func (r *Route) Get(pattern string, h ...any) {
r.Methods("GET", pattern, h...)
}

func (r *Route) Options(pattern string, h ...any) {
r.Methods("OPTIONS", pattern, h...)
}

// GetOptions delegate get and options method
func (r *Route) GetOptions(pattern string, h ...any) {
r.Methods("GET,OPTIONS", pattern, h...)
}

// PostOptions delegate post and options method
func (r *Route) PostOptions(pattern string, h ...any) {
r.Methods("POST,OPTIONS", pattern, h...)
}

// Head delegate head method
func (r *Route) Head(pattern string, h ...any) {
r.Methods("HEAD", pattern, h...)
Expand Down
4 changes: 1 addition & 3 deletions routers/api/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,7 @@ func Routes() *web.Route {
m.Use(securityHeaders())
if setting.CORSConfig.Enabled {
m.Use(cors.Handler(cors.Options{
// Scheme: setting.CORSConfig.Scheme, // FIXME: the cors middleware needs scheme option
AllowedOrigins: setting.CORSConfig.AllowDomain,
// setting.CORSConfig.AllowSubdomain // FIXME: the cors middleware needs allowSubdomain option
AllowedOrigins: setting.CORSConfig.AllowDomain,
AllowedMethods: setting.CORSConfig.Methods,
AllowCredentials: setting.CORSConfig.AllowCredentials,
AllowedHeaders: append([]string{"Authorization", "X-Gitea-OTP"}, setting.CORSConfig.Headers...),
Expand Down
22 changes: 11 additions & 11 deletions routers/web/githttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ func requireSignIn(ctx *context.Context) {

func gitHTTPRouters(m *web.Route) {
m.Group("", func() {
m.PostOptions("/git-upload-pack", repo.ServiceUploadPack)
m.PostOptions("/git-receive-pack", repo.ServiceReceivePack)
m.GetOptions("/info/refs", repo.GetInfoRefs)
m.GetOptions("/HEAD", repo.GetTextFile("HEAD"))
m.GetOptions("/objects/info/alternates", repo.GetTextFile("objects/info/alternates"))
m.GetOptions("/objects/info/http-alternates", repo.GetTextFile("objects/info/http-alternates"))
m.GetOptions("/objects/info/packs", repo.GetInfoPacks)
m.GetOptions("/objects/info/{file:[^/]*}", repo.GetTextFile(""))
m.GetOptions("/objects/{head:[0-9a-f]{2}}/{hash:[0-9a-f]{38}}", repo.GetLooseObject)
m.GetOptions("/objects/pack/pack-{file:[0-9a-f]{40}}.pack", repo.GetPackFile)
m.GetOptions("/objects/pack/pack-{file:[0-9a-f]{40}}.idx", repo.GetIdxFile)
m.Methods("POST,OPTIONS", "/git-upload-pack", repo.ServiceUploadPack)
m.Methods("POST,OPTIONS", "/git-receive-pack", repo.ServiceReceivePack)
m.Methods("GET,OPTIONS", "/info/refs", repo.GetInfoRefs)
m.Methods("GET,OPTIONS", "/HEAD", repo.GetTextFile("HEAD"))
m.Methods("GET,OPTIONS", "/objects/info/alternates", repo.GetTextFile("objects/info/alternates"))
m.Methods("GET,OPTIONS", "/objects/info/http-alternates", repo.GetTextFile("objects/info/http-alternates"))
m.Methods("GET,OPTIONS", "/objects/info/packs", repo.GetInfoPacks)
m.Methods("GET,OPTIONS", "/objects/info/{file:[^/]*}", repo.GetTextFile(""))
m.Methods("GET,OPTIONS", "/objects/{head:[0-9a-f]{2}}/{hash:[0-9a-f]{38}}", repo.GetLooseObject)
m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40}}.pack", repo.GetPackFile)
m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40}}.idx", repo.GetIdxFile)
}, ignSignInAndCsrf, requireSignIn, repo.HTTPGitEnabledHandler, repo.CorsHandler(), context_service.UserAssignmentWeb())
}
4 changes: 0 additions & 4 deletions routers/web/misc/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ func DummyOK(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK)
}

func DummyBadRequest(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}

func RobotsTxt(w http.ResponseWriter, req *http.Request) {
robotsTxt := util.FilePathJoinAbs(setting.CustomPath, "public/robots.txt")
if ok, _ := util.IsExist(robotsTxt); !ok {
Expand Down
50 changes: 31 additions & 19 deletions routers/web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ const (
GzipMinSize = 1400
)

// CorsHandler return a http handler who set CORS options if enabled by config
func CorsHandler() func(next http.Handler) http.Handler {
// optionsCorsHandler return a http handler which sets CORS options if enabled by config, it blocks non-CORS OPTIONS requests.
func optionsCorsHandler() func(next http.Handler) http.Handler {
var corsHandler func(next http.Handler) http.Handler
if setting.CORSConfig.Enabled {
return cors.Handler(cors.Options{
// Scheme: setting.CORSConfig.Scheme, // FIXME: the cors middleware needs scheme option
AllowedOrigins: setting.CORSConfig.AllowDomain,
// setting.CORSConfig.AllowSubdomain // FIXME: the cors middleware needs allowSubdomain option
corsHandler = cors.Handler(cors.Options{
AllowedOrigins: setting.CORSConfig.AllowDomain,
AllowedMethods: setting.CORSConfig.Methods,
AllowCredentials: setting.CORSConfig.AllowCredentials,
AllowedHeaders: setting.CORSConfig.Headers,
Expand All @@ -75,7 +74,23 @@ func CorsHandler() func(next http.Handler) http.Handler {
}

return func(next http.Handler) http.Handler {
return next
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
wxiaoguang marked this conversation as resolved.
Show resolved Hide resolved
if corsHandler != nil && r.Header.Get("Access-Control-Request-Method") != "" {
corsHandler(next).ServeHTTP(w, r)
} else {
// it should explicitly deny OPTIONS requests if CORS handler is not executed, to avoid the next GET/POST handler being incorrectly called by the OPTIONS request
w.WriteHeader(http.StatusMethodNotAllowed)
}
return
}
// for non-OPTIONS requests, call the CORS handler to add some related headers like "Vary"
if corsHandler != nil {
corsHandler(next).ServeHTTP(w, r)
} else {
next.ServeHTTP(w, r)
}
})
}
}

Expand Down Expand Up @@ -218,7 +233,7 @@ func Routes() *web.Route {
routes := web.NewRoute()

routes.Head("/", misc.DummyOK) // for health check - doesn't need to be passed through gzip handler
routes.Methods("GET, HEAD", "/assets/*", CorsHandler(), public.FileHandlerFunc())
routes.Methods("GET, HEAD, OPTIONS", "/assets/*", optionsCorsHandler(), public.FileHandlerFunc())
routes.Methods("GET, HEAD", "/avatars/*", storageHandler(setting.Avatar.Storage, "avatars", storage.Avatars))
routes.Methods("GET, HEAD", "/repo-avatars/*", storageHandler(setting.RepoAvatar.Storage, "repo-avatars", storage.RepoAvatars))
routes.Methods("GET, HEAD", "/apple-touch-icon.png", misc.StaticRedirect("/assets/img/apple-touch-icon.png"))
Expand Down Expand Up @@ -458,8 +473,8 @@ func registerRoutes(m *web.Route) {
m.Get("/change-password", func(ctx *context.Context) {
ctx.Redirect(setting.AppSubURL + "/user/settings/account")
})
m.Any("/*", CorsHandler(), public.FileHandlerFunc())
}, CorsHandler())
m.Methods("GET, HEAD", "/*", public.FileHandlerFunc())
}, optionsCorsHandler())

m.Group("/explore", func() {
m.Get("", func(ctx *context.Context) {
Expand Down Expand Up @@ -532,14 +547,11 @@ func registerRoutes(m *web.Route) {
// TODO manage redirection
m.Post("/authorize", web.Bind(forms.AuthorizationForm{}), auth.AuthorizeOAuth)
}, ignSignInAndCsrf, reqSignIn)
m.Options("/login/oauth/userinfo", CorsHandler(), misc.DummyBadRequest)
m.Get("/login/oauth/userinfo", ignSignInAndCsrf, auth.InfoOAuth)
m.Options("/login/oauth/access_token", CorsHandler(), misc.DummyBadRequest)
m.Post("/login/oauth/access_token", CorsHandler(), web.Bind(forms.AccessTokenForm{}), ignSignInAndCsrf, auth.AccessTokenOAuth)
m.Options("/login/oauth/keys", CorsHandler(), misc.DummyBadRequest)
m.Get("/login/oauth/keys", ignSignInAndCsrf, auth.OIDCKeys)
m.Options("/login/oauth/introspect", CorsHandler(), misc.DummyBadRequest)
m.Post("/login/oauth/introspect", CorsHandler(), web.Bind(forms.IntrospectTokenForm{}), ignSignInAndCsrf, auth.IntrospectOAuth)

m.Methods("GET, OPTIONS", "/login/oauth/userinfo", optionsCorsHandler(), ignSignInAndCsrf, auth.InfoOAuth)
m.Methods("POST, OPTIONS", "/login/oauth/access_token", optionsCorsHandler(), web.Bind(forms.AccessTokenForm{}), ignSignInAndCsrf, auth.AccessTokenOAuth)
m.Methods("GET, OPTIONS", "/login/oauth/keys", optionsCorsHandler(), ignSignInAndCsrf, auth.OIDCKeys)
m.Methods("POST, OPTIONS", "/login/oauth/introspect", optionsCorsHandler(), web.Bind(forms.IntrospectTokenForm{}), ignSignInAndCsrf, auth.IntrospectOAuth)

m.Group("/user/settings", func() {
m.Get("", user_setting.Profile)
Expand Down Expand Up @@ -770,7 +782,7 @@ func registerRoutes(m *web.Route) {

m.Group("", func() {
m.Get("/{username}", user.UsernameSubRoute)
m.Get("/attachments/{uuid}", repo.GetAttachment)
m.Methods("GET, OPTIONS", "/attachments/{uuid}", optionsCorsHandler(), repo.GetAttachment)
}, ignSignIn)

m.Post("/{username}", reqSignIn, context_service.UserAssignmentWeb(), user.Action)
Expand Down
85 changes: 78 additions & 7 deletions tests/integration/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,88 @@ import (
"net/http"
"testing"

"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"code.gitea.io/gitea/routers"
"code.gitea.io/gitea/tests"

"github.com/stretchr/testify/assert"
)

func TestCORSNotSet(t *testing.T) {
func TestCORS(t *testing.T) {
defer tests.PrepareTestEnv(t)()
req := NewRequest(t, "GET", "/api/v1/version")
session := loginUser(t, "user2")
resp := session.MakeRequest(t, req, http.StatusOK)
assert.Equal(t, resp.Code, http.StatusOK)
corsHeader := resp.Header().Get("Access-Control-Allow-Origin")
assert.Empty(t, corsHeader, "Access-Control-Allow-Origin: generated header should match") // header not set
t.Run("CORS enabled", func(t *testing.T) {
defer test.MockVariableValue(&setting.CORSConfig.Enabled, true)()
defer test.MockVariableValue(&testWebRoutes, routers.NormalRoutes())()

t.Run("API with CORS", func(t *testing.T) {
// GET api with no CORS header
req := NewRequest(t, "GET", "/api/v1/version")
resp := MakeRequest(t, req, http.StatusOK)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")

// OPTIONS api for CORS
req = NewRequest(t, "OPTIONS", "/api/v1/version").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusOK)
assert.NotEmpty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")
})

t.Run("Web with CORS", func(t *testing.T) {
// GET userinfo with no CORS header
req := NewRequest(t, "GET", "/login/oauth/userinfo")
resp := MakeRequest(t, req, http.StatusUnauthorized)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")

// OPTIONS userinfo for CORS
req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusOK)
assert.NotEmpty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")

// OPTIONS userinfo for non-CORS
req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")
})
})

t.Run("CORS disabled", func(t *testing.T) {
defer test.MockVariableValue(&setting.CORSConfig.Enabled, false)()
defer test.MockVariableValue(&testWebRoutes, routers.NormalRoutes())()

t.Run("API without CORS", func(t *testing.T) {
req := NewRequest(t, "GET", "/api/v1/version")
resp := MakeRequest(t, req, http.StatusOK)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Header().Values("Vary"))

req = NewRequest(t, "OPTIONS", "/api/v1/version").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Header().Values("Vary"))
})

t.Run("Web without CORS", func(t *testing.T) {
req := NewRequest(t, "GET", "/login/oauth/userinfo")
resp := MakeRequest(t, req, http.StatusUnauthorized)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")

req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")
})
})
}