Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORS and access controls #366

Merged
merged 14 commits into from
Nov 9, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Unreleased changes are available as `avenga/couper:edge` container.
* Handling of [`accept_forwarded_url`](./docs/REFERENCE.md#settings-block) "host" if `H-Forwarded-Host` request header field contains a port ([#360](https://github.com/avenga/couper/pull/360))
* Setting `Vary` response header fields for [CORS](./doc/REFERENCE.md#cors-block) ([#362](https://github.com/avenga/couper/pull/362))
* Use of referenced backends in [OAuth2 CC Blocks](./docs/REFERENCE.md#oauth2-cc-block) ([#321](https://github.com/avenga/couper/issues/321))
* [CORS](./doc/REFERENCE.md#cors-block) preflight requests are not blocked by access controls any more ([#366](https://github.com/avenga/couper/pull/366))

---

Expand Down
74 changes: 41 additions & 33 deletions config/runtime/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,27 +152,27 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa))
if cerr != nil {
return nil, cerr
}
h := middleware.NewCORSHandler(corsOptions, spaHandler)

spaHandler, err = configureProtectedHandler(accessControls, confCtx,
config.NewAccessControl(srvConf.AccessControl, srvConf.DisableAccessControl),
config.NewAccessControl(srvConf.Spa.AccessControl, srvConf.Spa.DisableAccessControl),
&protectedOptions{
epOpts: &handler.EndpointOptions{Error: serverOptions.ServerErrTpl},
handler: h,
handler: spaHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
}, nil, log)

if err != nil {
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa))
if cerr != nil {
return nil, cerr
}

spaHandler = middleware.NewCORSHandler(corsOptions, spaHandler)

for _, spaPath := range srvConf.Spa.Paths {
err = setRoutesFromHosts(serverConfiguration, portsHosts, path.Join(serverOptions.SPABasePath, spaPath), spaHandler, spa)
if err != nil {
Expand All @@ -182,34 +182,37 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}

if srvConf.Files != nil {
fileHandler, ferr := handler.NewFile(srvConf.Files.DocumentRoot, serverOptions, []hcl.Body{srvConf.Files.Remain, srvConf.Remain})
if ferr != nil {
return nil, ferr
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files))
if cerr != nil {
return nil, cerr
var (
fileHandler http.Handler
err error
)
fileHandler, err = handler.NewFile(srvConf.Files.DocumentRoot, serverOptions, []hcl.Body{srvConf.Files.Remain, srvConf.Remain})
if err != nil {
return nil, err
}

h := middleware.NewCORSHandler(corsOptions, fileHandler)

protectedFileHandler, err := configureProtectedHandler(accessControls, confCtx,
fileHandler, err = configureProtectedHandler(accessControls, confCtx,
config.NewAccessControl(srvConf.AccessControl, srvConf.DisableAccessControl),
config.NewAccessControl(srvConf.Files.AccessControl, srvConf.Files.DisableAccessControl),
&protectedOptions{
epOpts: &handler.EndpointOptions{Error: serverOptions.FilesErrTpl},
handler: h,
handler: fileHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
}, nil, log)

if err != nil {
return nil, err
}

err = setRoutesFromHosts(serverConfiguration, portsHosts, serverOptions.FilesBasePath, protectedFileHandler, files)
corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files))
if cerr != nil {
return nil, cerr
}

fileHandler = middleware.NewCORSHandler(corsOptions, fileHandler)

err = setRoutesFromHosts(serverConfiguration, portsHosts, serverOptions.FilesBasePath, fileHandler, files)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -238,10 +241,6 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
endpointPatterns[cleanPattern] = true

corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI))
if err != nil {
return nil, err
}
epOpts, err := newEndpointOptions(
confCtx, endpointConf, parentAPI, serverOptions,
log, conf.Settings.NoProxyFromEnv, memStore,
Expand All @@ -260,13 +259,13 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
epOpts.LogHandlerKind = kind.String()

epHandler := handler.NewEndpoint(epOpts, log, modifier)
protectedHandler := middleware.NewCORSHandler(corsOptions, epHandler)

accessControl := newAC(srvConf, parentAPI)
var epHandler http.Handler
if parentAPI != nil && parentAPI.CatchAllEndpoint == endpointConf {
protectedHandler = epOpts.Error.ServeError(errors.RouteNotFound)
epHandler = epOpts.Error.ServeError(errors.RouteNotFound)
} else {
epHandler = handler.NewEndpoint(epOpts, log, modifier)
}

scopeMaps := []map[string]string{}
if parentAPI != nil {
apiScopeMap, err := seetie.ValueToScopeMap(parentAPI.Scope)
Expand All @@ -281,11 +280,12 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
scopeMaps = append(scopeMaps, endpointScopeMap)
scopeControl := ac.NewScopeControl(scopeMaps)
endpointHandlers[endpointConf], err = configureProtectedHandler(accessControls, confCtx, accessControl,
accessControl := newAC(srvConf, parentAPI)
epHandler, err = configureProtectedHandler(accessControls, confCtx, accessControl,
config.NewAccessControl(endpointConf.AccessControl, endpointConf.DisableAccessControl),
&protectedOptions{
epOpts: epOpts,
handler: protectedHandler,
handler: epHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
Expand All @@ -294,6 +294,14 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI))
if err != nil {
return nil, err
}

epHandler = middleware.NewCORSHandler(corsOptions, epHandler)

endpointHandlers[endpointConf] = epHandler
err = setRoutesFromHosts(serverConfiguration, portsHosts, pattern, endpointHandlers[endpointConf], kind)
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,13 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {
}

func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
// see https://fetch.spec.whatwg.org/#http-responses
allowSpecificOrigin := false
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Add("Vary", "Origin")
allowSpecificOrigin = true
}

if !c.isCorsRequest(req) {
Expand All @@ -113,10 +116,7 @@ func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
return
}

// see https://fetch.spec.whatwg.org/#http-responses
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
if allowSpecificOrigin {
headers.Set("Access-Control-Allow-Origin", requestOrigin)
}

Expand Down
142 changes: 129 additions & 13 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3668,28 +3668,37 @@ func TestCORS_Configuration(t *testing.T) {
shutdown, _ := newCouper("testdata/integration/config/06_couper.hcl", test.New(t))
defer shutdown()

requestMethod := "GET"
requestHeaders := "Authorization"

type testCase struct {
path string
origin string
expAllowedOrigin bool
path string
origin string
expAllowed bool
expAllowedMethods string
expAllowedHeaders string
expVaryPF string
expVary string
expVaryCred string
}

for _, tc := range []testCase{
{"/06_couper.hcl", "a.com", true},
{"/spa/", "b.com", true},
{"/api/", "c.com", true},
{"/06_couper.hcl", "no.com", false},
{"/spa/", "", false},
{"/api/", "no.com", false},
{"/06_couper.hcl", "a.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/spa/", "b.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/api/", "c.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin"},
{"/06_couper.hcl", "no.com", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/spa/", "", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/api/", "no.com", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin"},
} {
t.Run(tc.path[1:], func(subT *testing.T) {
helper := test.New(subT)

// preflight request
req, err := http.NewRequest(http.MethodOptions, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Access-Control-Request-Method", "GET")
req.Header.Set("Access-Control-Request-Headers", "origin")
req.Header.Set("Access-Control-Request-Method", requestMethod)
req.Header.Set("Access-Control-Request-Headers", requestHeaders)
req.Header.Set("Origin", tc.origin)

res, err := client.Do(req)
Expand All @@ -3701,8 +3710,115 @@ func TestCORS_Configuration(t *testing.T) {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusNoContent, res.StatusCode)
}

if val, exist := res.Header["Access-Control-Allow-Origin"]; tc.expAllowedOrigin && (!exist || val[0] != tc.origin) {
subT.Errorf("Expected allowed origin resp, got: %v", val)
acao, acaoExists := res.Header["Access-Control-Allow-Origin"]
acam, acamExists := res.Header["Access-Control-Allow-Methods"]
acah, acahExists := res.Header["Access-Control-Allow-Headers"]
acac, acacExists := res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acamExists || acam[0] != tc.expAllowedMethods {
subT.Errorf("Expected allowed methods, got: %v", acam)
}
if !acahExists || acah[0] != tc.expAllowedHeaders {
subT.Errorf("Expected allowed headers, got: %v", acah)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists := res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryPF {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryPF, strings.Join(vary, ","))
}

// actual request lacking credentials -> rejected by basic_auth AC
req, err = http.NewRequest(requestMethod, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Origin", tc.origin)

res, err = client.Do(req)
helper.Must(err)

helper.Must(res.Body.Close())

if res.StatusCode != http.StatusUnauthorized {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusUnauthorized, res.StatusCode)
}

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVary {
subT.Errorf("Expected vary %q, got: %q", tc.expVary, strings.Join(vary, ","))
}

// actual request with credentials
req, err = http.NewRequest(requestMethod, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Origin", tc.origin)
req.Header.Set("Authorization", "Basic Zm9vOmFzZGY=")

res, err = client.Do(req)
helper.Must(err)

helper.Must(res.Body.Close())

if res.StatusCode != http.StatusOK {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusOK, res.StatusCode)
}

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryCred {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryCred, strings.Join(vary, ","))
}
})
}
Expand Down
11 changes: 11 additions & 0 deletions server/testdata/integration/config/06_couper.hcl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
server "cors" {
access_control = ["ba"]

files {
document_root = "./"
cors {
allowed_origins = "a.com"
allow_credentials = true
}
}

Expand All @@ -11,16 +14,24 @@ server "cors" {
bootstrap_file = "06_couper.hcl"
cors {
allowed_origins = "b.com"
allow_credentials = true
}
}

api {
base_path = "/api"
cors {
allowed_origins = "c.com"
allow_credentials = true
}
endpoint "/" {
response {}
}
}
}
definitions {
basic_auth "ba" {
user = "foo"
password = "asdf"
}
}