Skip to content

Commit

Permalink
CORS and access controls
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Schneider authored Nov 9, 2021
2 parents 5487f5d + d5e8f65 commit 208c331
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Unreleased changes are available as `avenga/couper:edge` container.
* Handling of [`accept_forwarded_url`](./docs/REFERENCE.md#settings-block) "host" if `H-Forwarded-Host` request header field contains a port ([#360](https://github.com/avenga/couper/pull/360))
* Setting `Vary` response header fields for [CORS](./doc/REFERENCE.md#cors-block) ([#362](https://github.com/avenga/couper/pull/362))
* Use of referenced backends in [OAuth2 CC Blocks](./docs/REFERENCE.md#oauth2-cc-block) ([#321](https://github.com/avenga/couper/issues/321))
* [CORS](./doc/REFERENCE.md#cors-block) preflight requests are not blocked by access controls any more ([#366](https://github.com/avenga/couper/pull/366))

---

Expand Down
74 changes: 41 additions & 33 deletions config/runtime/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,27 +152,27 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa))
if cerr != nil {
return nil, cerr
}
h := middleware.NewCORSHandler(corsOptions, spaHandler)

spaHandler, err = configureProtectedHandler(accessControls, confCtx,
config.NewAccessControl(srvConf.AccessControl, srvConf.DisableAccessControl),
config.NewAccessControl(srvConf.Spa.AccessControl, srvConf.Spa.DisableAccessControl),
&protectedOptions{
epOpts: &handler.EndpointOptions{Error: serverOptions.ServerErrTpl},
handler: h,
handler: spaHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
}, nil, log)

if err != nil {
return nil, err
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Spa))
if cerr != nil {
return nil, cerr
}

spaHandler = middleware.NewCORSHandler(corsOptions, spaHandler)

for _, spaPath := range srvConf.Spa.Paths {
err = setRoutesFromHosts(serverConfiguration, portsHosts, path.Join(serverOptions.SPABasePath, spaPath), spaHandler, spa)
if err != nil {
Expand All @@ -182,34 +182,37 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}

if srvConf.Files != nil {
fileHandler, ferr := handler.NewFile(srvConf.Files.DocumentRoot, serverOptions, []hcl.Body{srvConf.Files.Remain, srvConf.Remain})
if ferr != nil {
return nil, ferr
}

corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files))
if cerr != nil {
return nil, cerr
var (
fileHandler http.Handler
err error
)
fileHandler, err = handler.NewFile(srvConf.Files.DocumentRoot, serverOptions, []hcl.Body{srvConf.Files.Remain, srvConf.Remain})
if err != nil {
return nil, err
}

h := middleware.NewCORSHandler(corsOptions, fileHandler)

protectedFileHandler, err := configureProtectedHandler(accessControls, confCtx,
fileHandler, err = configureProtectedHandler(accessControls, confCtx,
config.NewAccessControl(srvConf.AccessControl, srvConf.DisableAccessControl),
config.NewAccessControl(srvConf.Files.AccessControl, srvConf.Files.DisableAccessControl),
&protectedOptions{
epOpts: &handler.EndpointOptions{Error: serverOptions.FilesErrTpl},
handler: h,
handler: fileHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
}, nil, log)

if err != nil {
return nil, err
}

err = setRoutesFromHosts(serverConfiguration, portsHosts, serverOptions.FilesBasePath, protectedFileHandler, files)
corsOptions, cerr := middleware.NewCORSOptions(whichCORS(srvConf, srvConf.Files))
if cerr != nil {
return nil, cerr
}

fileHandler = middleware.NewCORSHandler(corsOptions, fileHandler)

err = setRoutesFromHosts(serverConfiguration, portsHosts, serverOptions.FilesBasePath, fileHandler, files)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -238,10 +241,6 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
endpointPatterns[cleanPattern] = true

corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI))
if err != nil {
return nil, err
}
epOpts, err := newEndpointOptions(
confCtx, endpointConf, parentAPI, serverOptions,
log, conf.Settings.NoProxyFromEnv, memStore,
Expand All @@ -260,13 +259,13 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
epOpts.LogHandlerKind = kind.String()

epHandler := handler.NewEndpoint(epOpts, log, modifier)
protectedHandler := middleware.NewCORSHandler(corsOptions, epHandler)

accessControl := newAC(srvConf, parentAPI)
var epHandler http.Handler
if parentAPI != nil && parentAPI.CatchAllEndpoint == endpointConf {
protectedHandler = epOpts.Error.ServeError(errors.RouteNotFound)
epHandler = epOpts.Error.ServeError(errors.RouteNotFound)
} else {
epHandler = handler.NewEndpoint(epOpts, log, modifier)
}

scopeMaps := []map[string]string{}
if parentAPI != nil {
apiScopeMap, err := seetie.ValueToScopeMap(parentAPI.Scope)
Expand All @@ -281,11 +280,12 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
}
scopeMaps = append(scopeMaps, endpointScopeMap)
scopeControl := ac.NewScopeControl(scopeMaps)
endpointHandlers[endpointConf], err = configureProtectedHandler(accessControls, confCtx, accessControl,
accessControl := newAC(srvConf, parentAPI)
epHandler, err = configureProtectedHandler(accessControls, confCtx, accessControl,
config.NewAccessControl(endpointConf.AccessControl, endpointConf.DisableAccessControl),
&protectedOptions{
epOpts: epOpts,
handler: protectedHandler,
handler: epHandler,
memStore: memStore,
proxyFromEnv: conf.Settings.NoProxyFromEnv,
srvOpts: serverOptions,
Expand All @@ -294,6 +294,14 @@ func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *ca
return nil, err
}

corsOptions, err := middleware.NewCORSOptions(whichCORS(srvConf, parentAPI))
if err != nil {
return nil, err
}

epHandler = middleware.NewCORSHandler(corsOptions, epHandler)

endpointHandlers[endpointConf] = epHandler
err = setRoutesFromHosts(serverConfiguration, portsHosts, pattern, endpointHandlers[endpointConf], kind)
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,13 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {
}

func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
// see https://fetch.spec.whatwg.org/#http-responses
allowSpecificOrigin := false
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Add("Vary", "Origin")
allowSpecificOrigin = true
}

if !c.isCorsRequest(req) {
Expand All @@ -113,10 +116,7 @@ func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
return
}

// see https://fetch.spec.whatwg.org/#http-responses
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
if allowSpecificOrigin {
headers.Set("Access-Control-Allow-Origin", requestOrigin)
}

Expand Down
142 changes: 129 additions & 13 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3668,28 +3668,37 @@ func TestCORS_Configuration(t *testing.T) {
shutdown, _ := newCouper("testdata/integration/config/06_couper.hcl", test.New(t))
defer shutdown()

requestMethod := "GET"
requestHeaders := "Authorization"

type testCase struct {
path string
origin string
expAllowedOrigin bool
path string
origin string
expAllowed bool
expAllowedMethods string
expAllowedHeaders string
expVaryPF string
expVary string
expVaryCred string
}

for _, tc := range []testCase{
{"/06_couper.hcl", "a.com", true},
{"/spa/", "b.com", true},
{"/api/", "c.com", true},
{"/06_couper.hcl", "no.com", false},
{"/spa/", "", false},
{"/api/", "no.com", false},
{"/06_couper.hcl", "a.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/spa/", "b.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/api/", "c.com", true, requestMethod, requestHeaders, "Origin,Access-Control-Request-Method,Access-Control-Request-Headers", "Origin,Accept-Encoding", "Origin"},
{"/06_couper.hcl", "no.com", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/spa/", "", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin,Accept-Encoding"},
{"/api/", "no.com", false, "", "", "Origin", "Origin,Accept-Encoding", "Origin"},
} {
t.Run(tc.path[1:], func(subT *testing.T) {
helper := test.New(subT)

// preflight request
req, err := http.NewRequest(http.MethodOptions, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Access-Control-Request-Method", "GET")
req.Header.Set("Access-Control-Request-Headers", "origin")
req.Header.Set("Access-Control-Request-Method", requestMethod)
req.Header.Set("Access-Control-Request-Headers", requestHeaders)
req.Header.Set("Origin", tc.origin)

res, err := client.Do(req)
Expand All @@ -3701,8 +3710,115 @@ func TestCORS_Configuration(t *testing.T) {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusNoContent, res.StatusCode)
}

if val, exist := res.Header["Access-Control-Allow-Origin"]; tc.expAllowedOrigin && (!exist || val[0] != tc.origin) {
subT.Errorf("Expected allowed origin resp, got: %v", val)
acao, acaoExists := res.Header["Access-Control-Allow-Origin"]
acam, acamExists := res.Header["Access-Control-Allow-Methods"]
acah, acahExists := res.Header["Access-Control-Allow-Headers"]
acac, acacExists := res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acamExists || acam[0] != tc.expAllowedMethods {
subT.Errorf("Expected allowed methods, got: %v", acam)
}
if !acahExists || acah[0] != tc.expAllowedHeaders {
subT.Errorf("Expected allowed headers, got: %v", acah)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists := res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryPF {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryPF, strings.Join(vary, ","))
}

// actual request lacking credentials -> rejected by basic_auth AC
req, err = http.NewRequest(requestMethod, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Origin", tc.origin)

res, err = client.Do(req)
helper.Must(err)

helper.Must(res.Body.Close())

if res.StatusCode != http.StatusUnauthorized {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusUnauthorized, res.StatusCode)
}

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVary {
subT.Errorf("Expected vary %q, got: %q", tc.expVary, strings.Join(vary, ","))
}

// actual request with credentials
req, err = http.NewRequest(requestMethod, "http://localhost:8080"+tc.path, nil)
helper.Must(err)

req.Header.Set("Origin", tc.origin)
req.Header.Set("Authorization", "Basic Zm9vOmFzZGY=")

res, err = client.Do(req)
helper.Must(err)

helper.Must(res.Body.Close())

if res.StatusCode != http.StatusOK {
subT.Fatalf("%q: expected Status %d, got: %d", tc.path, http.StatusOK, res.StatusCode)
}

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
}
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryCred {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryCred, strings.Join(vary, ","))
}
})
}
Expand Down
11 changes: 11 additions & 0 deletions server/testdata/integration/config/06_couper.hcl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
server "cors" {
access_control = ["ba"]

files {
document_root = "./"
cors {
allowed_origins = "a.com"
allow_credentials = true
}
}

Expand All @@ -11,16 +14,24 @@ server "cors" {
bootstrap_file = "06_couper.hcl"
cors {
allowed_origins = "b.com"
allow_credentials = true
}
}

api {
base_path = "/api"
cors {
allowed_origins = "c.com"
allow_credentials = true
}
endpoint "/" {
response {}
}
}
}
definitions {
basic_auth "ba" {
user = "foo"
password = "asdf"
}
}

0 comments on commit 208c331

Please sign in to comment.