diff --git a/http/events.go b/http/events.go index 478f31704250..22dfe35630db 100644 --- a/http/events.go +++ b/http/events.go @@ -2,6 +2,7 @@ package http import ( "context" + "errors" "fmt" "net/http" "strconv" @@ -66,13 +67,25 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo } } -func handleEventsSubscribe(core *vault.Core) http.Handler { +func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger := core.Logger().Named("events-subscribe") - logger.Debug("Got request to", "url", r.URL, "version", r.Proto) ctx := r.Context() + + // ACL check + _, _, err := core.CheckToken(ctx, req, false) + if err != nil { + if errors.Is(err, logical.ErrPermissionDenied) { + respondError(w, http.StatusUnauthorized, logical.ErrPermissionDenied) + return + } + logger.Debug("Error validating token", "error", err) + respondError(w, http.StatusInternalServerError, fmt.Errorf("error validating token")) + return + } + ns, err := namespace.FromContext(ctx) if err != nil { logger.Info("Could not find namespace", "error", err) diff --git a/http/events_test.go b/http/events_test.go index d3729781df4c..645ecd403a4a 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -2,6 +2,7 @@ package http import ( "context" + "net/http" "strings" "sync/atomic" "testing" @@ -21,6 +22,15 @@ func TestEventsSubscribe(t *testing.T) { ln, addr := TestServer(t, core) defer ln.Close() + // unseal the core + keys, token := vault.TestCoreInit(t, core) + for _, key := range keys { + _, err := core.Unseal(key) + if err != nil { + t.Fatal(err) + } + } + stop := atomic.Bool{} eventType := "abc" @@ -53,7 +63,18 @@ func TestEventsSubscribe(t *testing.T) { t.Cleanup(cancelFunc) wsAddr := strings.Replace(addr, "http", "ws", 1) - conn, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", nil) + + // check that the connection fails if we don't have a token + _, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", nil) + if err == nil { + t.Error("Expected websocket error but got none") + } else if !strings.HasSuffix(err.Error(), "401") { + t.Errorf("Expected 401 websocket but got %v", err) + } + + conn, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", &websocket.DialOptions{ + HTTPHeader: http.Header{"x-vault-token": []string{token}}, + }) if err != nil { t.Fatal(err) } diff --git a/http/logical.go b/http/logical.go index 7dc3ad9e832e..6b12a26f3bfe 100644 --- a/http/logical.go +++ b/http/logical.go @@ -13,7 +13,7 @@ import ( "strings" "time" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/experiments" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" @@ -359,7 +359,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw nsPath = "" } if strings.HasPrefix(r.URL.Path, fmt.Sprintf("/v1/%ssys/events/subscribe/", nsPath)) { - handler := handleEventsSubscribe(core) + handler := handleEventsSubscribe(core, req) handler.ServeHTTP(w, r) return } diff --git a/vault/request_handling.go b/vault/request_handling.go index 1d56d488a7ef..ff041838db2a 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -260,7 +260,7 @@ func (c *Core) fetchACLTokenEntryAndEntity(ctx context.Context, req *logical.Req return acl, te, entity, identityPolicies, nil } -func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool) (*logical.Auth, *logical.TokenEntry, error) { +func (c *Core) CheckToken(ctx context.Context, req *logical.Request, unauth bool) (*logical.Auth, *logical.TokenEntry, error) { defer metrics.MeasureSince([]string{"core", "check_token"}, time.Now()) var acl *ACL @@ -857,7 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp } // Validate the token - auth, te, ctErr := c.checkToken(ctx, req, false) + auth, te, ctErr := c.CheckToken(ctx, req, false) if ctErr == logical.ErrRelativePath { return logical.ErrorResponse(ctErr.Error()), nil, ctErr } @@ -1272,7 +1272,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re // Do an unauth check. This will cause EGP policies to be checked var auth *logical.Auth var ctErr error - auth, _, ctErr = c.checkToken(ctx, req, true) + auth, _, ctErr = c.CheckToken(ctx, req, true) if ctErr == logical.ErrPerfStandbyPleaseForward { return nil, nil, ctErr }