diff --git a/http/cors.go b/http/cors.go index 2a25a377fc67..a01228be2da6 100644 --- a/http/cors.go +++ b/http/cors.go @@ -9,11 +9,6 @@ import ( "github.com/hashicorp/vault/vault" ) -var preflightHeaders = map[string]string{ - "Access-Control-Allow-Headers": "*", - "Access-Control-Max-Age": "300", -} - var allowedMethods = []string{ http.MethodDelete, http.MethodGet, @@ -38,8 +33,7 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { return } - // Return a 403 if the origin is not - // allowed to make cross-origin requests. + // Return a 403 if the origin is not allowed to make cross-origin requests. if !corsConf.IsValidOrigin(origin) { respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed")) return @@ -56,10 +50,9 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { // apply headers for preflight requests if req.Method == http.MethodOptions { w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ",")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(corsConf.AllowedHeaders, ",")) + w.Header().Set("Access-Control-Max-Age", "300") - for k, v := range preflightHeaders { - w.Header().Set(k, v) - } return } diff --git a/http/handler_test.go b/http/handler_test.go index 40885fbd8249..0c9e08eda40c 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" "github.com/hashicorp/go-cleanhttp" @@ -21,7 +22,7 @@ func TestHandler_cors(t *testing.T) { // Enable CORS and allow from any origin for testing. corsConfig := core.CORSConfig() - err := corsConfig.Enable([]string{addr}) + err := corsConfig.Enable([]string{addr}, nil) if err != nil { t.Fatalf("Error enabling CORS: %s", err) } @@ -78,7 +79,7 @@ func TestHandler_cors(t *testing.T) { // expHeaders := map[string]string{ "Access-Control-Allow-Origin": addr, - "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Headers": strings.Join(stdAllowedHeaders, ","), "Access-Control-Max-Age": "300", "Vary": "Origin", } diff --git a/http/sys_config_cors_test.go b/http/sys_config_cors_test.go new file mode 100644 index 000000000000..bd6c7aeae83d --- /dev/null +++ b/http/sys_config_cors_test.go @@ -0,0 +1,78 @@ +package http + +import ( + "encoding/json" + "net/http" + "reflect" + "testing" + + "github.com/hashicorp/vault/vault" +) + +func TestSysConfigCors(t *testing.T) { + var resp *http.Response + + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + corsConf := core.CORSConfig() + + // Try to enable CORS without providing a value for allowed_origins + resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{ + "allowed_headers": "X-Custom-Header", + }) + + testResponseStatus(t, resp, 500) + + // Enable CORS, but provide an origin this time. + resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{ + "allowed_origins": addr, + "allowed_headers": "X-Custom-Header", + }) + + testResponseStatus(t, resp, 204) + + // Read the CORS configuration + resp = testHttpGet(t, token, addr+"/v1/sys/config/cors") + testResponseStatus(t, resp, 200) + + var actual map[string]interface{} + var expected map[string]interface{} + + lenStdHeaders := len(corsConf.AllowedHeaders) + + expectedHeaders := make([]interface{}, lenStdHeaders) + + for i := range corsConf.AllowedHeaders { + expectedHeaders[i] = corsConf.AllowedHeaders[i] + } + + expected = map[string]interface{}{ + "lease_id": "", + "renewable": false, + "lease_duration": json.Number("0"), + "wrap_info": nil, + "warnings": nil, + "auth": nil, + "data": map[string]interface{}{ + "enabled": true, + "allowed_origins": []interface{}{addr}, + "allowed_headers": expectedHeaders, + }, + "enabled": true, + "allowed_origins": []interface{}{addr}, + "allowed_headers": expectedHeaders, + } + + testResponseStatus(t, resp, 200) + + testResponseBody(t, resp, &actual) + expected["request_id"] = actual["request_id"] + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("bad: expected: %#v\nactual: %#v", expected, actual) + } + +} diff --git a/vault/core.go b/vault/core.go index 75b754ee7050..861e568edb33 100644 --- a/vault/core.go +++ b/vault/core.go @@ -459,8 +459,8 @@ func NewCore(conf *CoreConfig) (*Core, error) { enableMlock: !conf.DisableMlock, } - // Load CORS config and provide core c.corsConfig = &CORSConfig{core: c} + // Load CORS config and provide a value for the core field. // Wrap the physical backend in a cache layer if enabled and not already wrapped if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache { diff --git a/vault/cors.go b/vault/cors.go index c1fd961284c2..de201200bf43 100644 --- a/vault/cors.go +++ b/vault/cors.go @@ -15,12 +15,24 @@ const ( CORSEnabled ) +var stdAllowedHeaders = []string{ + "Content-Type", + "X-Requested-With", + "X-Vault-AWS-IAM-Server-ID", + "X-Vault-MFA", + "X-Vault-No-Request-Forwarding", + "X-Vault-Token", + "X-Vault-Wrap-Format", + "X-Vault-Wrap-TTL", +} + // CORSConfig stores the state of the CORS configuration. type CORSConfig struct { sync.RWMutex `json:"-"` core *Core Enabled uint32 `json:"enabled"` AllowedOrigins []string `json:"allowed_origins,omitempty"` + AllowedHeaders []string `json:"allowed_headers,omitempty"` } func (c *Core) saveCORSConfig() error { @@ -31,6 +43,7 @@ func (c *Core) saveCORSConfig() error { } c.corsConfig.RLock() localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins + localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders c.corsConfig.RUnlock() entry, err := logical.StorageEntryJSON("cors", localConfig) @@ -72,9 +85,9 @@ func (c *Core) loadCORSConfig() error { // Enable takes either a '*' or a comma-seprated list of URLs that can make // cross-origin requests to Vault. -func (c *CORSConfig) Enable(urls []string) error { +func (c *CORSConfig) Enable(urls []string, headers []string) error { if len(urls) == 0 { - return errors.New("the list of allowed origins cannot be empty") + return errors.New("at least one origin or the wildcard must be provided.") } if strutil.StrListContains(urls, "*") && len(urls) > 1 { @@ -83,6 +96,15 @@ func (c *CORSConfig) Enable(urls []string) error { c.Lock() c.AllowedOrigins = urls + + // Start with the standard headers to Vault accepts. + c.AllowedHeaders = append(c.AllowedHeaders, stdAllowedHeaders...) + + // Allow the user to add additional headers to the list of + // headers allowed on cross-origin requests. + if len(headers) > 0 { + c.AllowedHeaders = append(c.AllowedHeaders, headers...) + } c.Unlock() atomic.StoreUint32(&c.Enabled, CORSEnabled) @@ -95,12 +117,16 @@ func (c *CORSConfig) IsEnabled() bool { return atomic.LoadUint32(&c.Enabled) == CORSEnabled } -// Disable sets CORS to disabled and clears the allowed origins +// Disable sets CORS to disabled and clears the allowed origins & headers. func (c *CORSConfig) Disable() error { atomic.StoreUint32(&c.Enabled, CORSDisabled) c.Lock() - c.AllowedOrigins = []string(nil) + + c.AllowedOrigins = nil + c.AllowedHeaders = nil + c.Unlock() + return c.core.saveCORSConfig() } diff --git a/vault/logical_system.go b/vault/logical_system.go index 848eea530049..f620e18cbaaf 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -115,6 +115,10 @@ func NewSystemBackend(core *Core) *SystemBackend { Type: framework.TypeCommaStringSlice, Description: "A comma-separated string or array of strings indicating origins that may make cross-origin requests.", }, + "allowed_headers": &framework.FieldSchema{ + Type: framework.TypeCommaStringSlice, + Description: "A comma-separated string or array of strings indicating headers that are allowed on cross-origin requests.", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -854,6 +858,7 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD if enabled { corsConf.RLock() resp.Data["allowed_origins"] = corsConf.AllowedOrigins + resp.Data["allowed_headers"] = corsConf.AllowedHeaders corsConf.RUnlock() } @@ -864,12 +869,13 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD // cross-origin requests and sets the CORS enabled flag to true func (b *SystemBackend) handleCORSUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { origins := d.Get("allowed_origins").([]string) + headers := d.Get("allowed_headers").([]string) - return nil, b.Core.corsConfig.Enable(origins) + return nil, b.Core.corsConfig.Enable(origins, headers) } -// handleCORSDelete clears the allowed origins and sets the CORS enabled flag -// to false +// handleCORSDelete sets the CORS enabled flag to false and clears the list of +// allowed origins & headers. func (b *SystemBackend) handleCORSDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return nil, b.Core.corsConfig.Disable() } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 55cf606b4ff2..8b2050670002 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -56,6 +56,7 @@ func TestSystemConfigCORS(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "config/cors") req.Data["allowed_origins"] = "http://www.example.com" + req.Data["allowed_headers"] = "X-Custom-Header" _, err := b.HandleRequest(req) if err != nil { t.Fatal(err) @@ -65,6 +66,7 @@ func TestSystemConfigCORS(t *testing.T) { Data: map[string]interface{}{ "enabled": true, "allowed_origins": []string{"http://www.example.com"}, + "allowed_headers": append(stdAllowedHeaders, "X-Custom-Header"), }, } diff --git a/website/source/api/system/config-cors.html.md b/website/source/api/system/config-cors.html.md index 659a74a38cb4..26c5b42a3295 100644 --- a/website/source/api/system/config-cors.html.md +++ b/website/source/api/system/config-cors.html.md @@ -34,14 +34,23 @@ $ curl \ ```json { "enabled": true, - "allowed_origins": "http://www.example.com" + "allowed_origins": ["http://www.example.com"], + "allowed_headers": [ + "Content-Type", + "X-Requested-With", + "X-Vault-AWS-IAM-Server-ID", + "X-Vault-No-Request-Forwarding", + "X-Vault-Token", + "X-Vault-Wrap-Format", + "X-Vault-Wrap-TTL", + ] } ``` ## Configure CORS Settings This endpoint allows configuring the origins that are permitted to make -cross-origin requests. +cross-origin requests, as well as headers that are allowed on cross-origin requests. | Method | Path | Produces | | :------- | :--------------------------- | :--------------------- | @@ -49,13 +58,16 @@ cross-origin requests. ### Parameters -- `allowed_origins` `(string or string array: "" or [])` – A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests. +- `allowed_origins` `(string or string array: )` – A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests. + +- `allowed_headers` `(string or string array: "" or [])` – A comma-delimited string or array of strings specifying headers that are permitted to be on cross-origin requests. Headers set via this parameter will be appended to the list of headers that Vault allows by default. ### Sample Payload ```json { - "allowed_origins": "*" + "allowed_origins": "*", + "allowed_headers": "X-Custom-Header" } ```