diff --git a/cmd/armada/main.go b/cmd/armada/main.go index a54058f9be0..bfb884d9908 100644 --- a/cmd/armada/main.go +++ b/cmd/armada/main.go @@ -20,7 +20,6 @@ import ( "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/pkg/api" ) @@ -64,11 +63,9 @@ func main() { }) // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - g.Go(func() error { - return serve.ListenAndServe(ctx, pprofServer) - }) + err := profiling.SetupPprof(config.Profiling, ctx, g) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } // TODO This starts a separate HTTP server. Is that intended? Should we have a single mux for everything? diff --git a/cmd/binoculars/main.go b/cmd/binoculars/main.go index 4ea6bb51d36..86389fbc91b 100644 --- a/cmd/binoculars/main.go +++ b/cmd/binoculars/main.go @@ -19,9 +19,7 @@ import ( "github.com/armadaproject/armada/internal/common/armadacontext" gateway "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" api "github.com/armadaproject/armada/pkg/api/binoculars" ) @@ -47,14 +45,9 @@ func main() { log.Info("Starting...") // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } stopSignal := make(chan os.Signal, 1) diff --git a/cmd/executor/main.go b/cmd/executor/main.go index 7090b712263..04e53b73f41 100644 --- a/cmd/executor/main.go +++ b/cmd/executor/main.go @@ -7,16 +7,14 @@ import ( "syscall" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/armadaproject/armada/internal/common" "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/health" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/executor" "github.com/armadaproject/armada/internal/executor/configuration" ) @@ -41,14 +39,9 @@ func main() { common.LoadConfig(&config, "./config/executor", userSpecifiedConfigs) // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } mux := http.NewServeMux() @@ -68,7 +61,7 @@ func main() { ) defer shutdownMetricServer() - shutdown, wg := executor.StartUp(armadacontext.Background(), logrus.NewEntry(logrus.StandardLogger()), config) + shutdown, wg := executor.StartUp(armadacontext.Background(), log.NewEntry(log.StandardLogger()), config) go func() { <-shutdownChannel shutdown() diff --git a/cmd/fakeexecutor/main.go b/cmd/fakeexecutor/main.go index 69850e97679..831d56b650e 100644 --- a/cmd/fakeexecutor/main.go +++ b/cmd/fakeexecutor/main.go @@ -5,14 +5,13 @@ import ( "os/signal" "syscall" + log "github.com/sirupsen/logrus" "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/armadaproject/armada/internal/common" "github.com/armadaproject/armada/internal/common/armadacontext" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/executor/configuration" "github.com/armadaproject/armada/internal/executor/fake" "github.com/armadaproject/armada/internal/executor/fake/context" @@ -38,14 +37,9 @@ func main() { v := common.LoadConfig(&config, "./config/executor", userSpecifiedConfigs) // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } var nodes []*context.NodeSpec diff --git a/cmd/lookoutv2/main.go b/cmd/lookoutv2/main.go index 83d53913e17..0f6fa3ec7bc 100644 --- a/cmd/lookoutv2/main.go +++ b/cmd/lookoutv2/main.go @@ -14,9 +14,7 @@ import ( "github.com/armadaproject/armada/internal/common" "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/lookoutv2" "github.com/armadaproject/armada/internal/lookoutv2/configuration" "github.com/armadaproject/armada/internal/lookoutv2/gen/restapi" @@ -127,14 +125,9 @@ func main() { common.LoadConfig(&config, "./config/lookoutv2", userSpecifiedConfigs) // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } log.SetLevel(log.DebugLevel) diff --git a/deployment/armada/templates/deployment.yaml b/deployment/armada/templates/deployment.yaml index 7753dc05e01..7683e97108f 100644 --- a/deployment/armada/templates/deployment.yaml +++ b/deployment/armada/templates/deployment.yaml @@ -69,8 +69,8 @@ spec: - containerPort: {{ .Values.applicationConfig.httpPort }} protocol: TCP name: rest - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/binoculars/templates/deployment.yaml b/deployment/binoculars/templates/deployment.yaml index 3cf0f536a64..2bb8c051a42 100644 --- a/deployment/binoculars/templates/deployment.yaml +++ b/deployment/binoculars/templates/deployment.yaml @@ -54,8 +54,8 @@ spec: - containerPort: {{ .Values.applicationConfig.httpPort }} protocol: TCP name: web - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/event-ingester/templates/deployment.yaml b/deployment/event-ingester/templates/deployment.yaml index c006a95dc40..7c7943aba29 100644 --- a/deployment/event-ingester/templates/deployment.yaml +++ b/deployment/event-ingester/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: resources: {{- toYaml .Values.resources | nindent 12 }} ports: - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/executor/templates/deployment.yaml b/deployment/executor/templates/deployment.yaml index 6c1b2c1b6c5..1694a53a4d7 100644 --- a/deployment/executor/templates/deployment.yaml +++ b/deployment/executor/templates/deployment.yaml @@ -52,8 +52,8 @@ spec: - containerPort: 9001 protocol: TCP name: metrics - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/lookout-ingester-v2/templates/deployment.yaml b/deployment/lookout-ingester-v2/templates/deployment.yaml index af045eb8a6f..efe46d88554 100644 --- a/deployment/lookout-ingester-v2/templates/deployment.yaml +++ b/deployment/lookout-ingester-v2/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: resources: {{- toYaml .Values.resources | nindent 12 }} ports: - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/lookout-v2/templates/deployment.yaml b/deployment/lookout-v2/templates/deployment.yaml index 8e4fdb1d0a9..f97d9f9ae8b 100644 --- a/deployment/lookout-v2/templates/deployment.yaml +++ b/deployment/lookout-v2/templates/deployment.yaml @@ -48,8 +48,8 @@ spec: - containerPort: {{ .Values.applicationConfig.apiPort }} protocol: TCP name: web - {{- if .Values.applicationConfig.pprofPort }} - - containerPort: {{ .Values.applicationConfig.pprofPort }} + {{- if and .Values.applicationConfig.profiling .Values.applicationConfig.profiling.port }} + - containerPort: {{ .Values.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/scheduler/templates/scheduler-ingester-deployment.yaml b/deployment/scheduler/templates/scheduler-ingester-deployment.yaml index 613be5acca0..b0d243d5fc7 100644 --- a/deployment/scheduler/templates/scheduler-ingester-deployment.yaml +++ b/deployment/scheduler/templates/scheduler-ingester-deployment.yaml @@ -45,8 +45,8 @@ spec: resources: {{- toYaml .Values.ingester.resources | nindent 12 }} ports: - {{- if .Values.ingester.applicationConfig.pprofPort }} - - containerPort: {{ .Values.ingester.applicationConfig.pprofPort }} + {{- if and .Values.ingester.applicationConfig.profiling .Values.ingester.applicationConfig.profiling.port }} + - containerPort: {{ .Values.ingester.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/deployment/scheduler/templates/scheduler-statefulset.yaml b/deployment/scheduler/templates/scheduler-statefulset.yaml index eff4ea1581e..6b1c4e8ade2 100644 --- a/deployment/scheduler/templates/scheduler-statefulset.yaml +++ b/deployment/scheduler/templates/scheduler-statefulset.yaml @@ -76,8 +76,8 @@ spec: - containerPort: {{ .Values.scheduler.applicationConfig.metrics.port }} protocol: TCP name: metrics - {{- if .Values.scheduler.applicationConfig.pprofPort }} - - containerPort: {{ .Values.scheduler.applicationConfig.pprofPort }} + {{- if and .Values.scheduler.applicationConfig.profiling .Values.scheduler.applicationConfig.profiling.port }} + - containerPort: {{ .Values.scheduler.applicationConfig.profiling.port }} protocol: TCP name: pprof {{- end }} diff --git a/docs/developer/pprof.md b/docs/developer/pprof.md new file mode 100644 index 00000000000..f5f2ddb0989 --- /dev/null +++ b/docs/developer/pprof.md @@ -0,0 +1,14 @@ +# Use of pprof + +- Go provides a profiling tool called pprof. It's documented at https://pkg.go.dev/net/http/pprof. +- If you wish to use this with Armada, enable the profiling socket with the following config (this should be under `applicationConfig` if using the helm charts). This config will listen on the specified port with no auth. + ``` + profiling: + port: 6060 + auth: + anonymousAuth: true + permissionGroupMapping: + pprof: ["everyone"] + ``` +- It's possible to put pprof behind auth if you want, see [api.md#authentication](./api.md#authentication) and [oidc.md](./oidc.md). +- The helm charts do not currently expose the profiling port via a service and ingress. You can use `kubectl port-forward` to access them. diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index c28f40037ba..dd3264c4a48 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -9,6 +9,7 @@ import ( authconfig "github.com/armadaproject/armada/internal/common/auth/configuration" grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/pkg/client" ) @@ -19,8 +20,7 @@ type ArmadaConfig struct { GrpcPort uint16 HttpPort uint16 MetricsPort uint16 - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + Profiling *profilingconfig.ProfilingConfig CorsAllowedOrigins []string GrpcGatewayPath string diff --git a/internal/armada/event/event_test.go b/internal/armada/event/event_test.go index 8a45147ef33..95692bca8c5 100644 --- a/internal/armada/event/event_test.go +++ b/internal/armada/event/event_test.go @@ -326,7 +326,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { err := s.queueRepository.(armadaqueue.QueueRepository).CreateQueue(ctx, q) assert.NoError(t, err) - principal := auth.NewStaticPrincipal("alice", []string{}) + principal := auth.NewStaticPrincipal("alice", "test", []string{}) ctx := auth.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} @@ -351,7 +351,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { err := s.queueRepository.(armadaqueue.QueueRepository).CreateQueue(ctx, q) assert.NoError(t, err) - principal := auth.NewStaticPrincipal("alice", []string{"watch-all-events-group"}) + principal := auth.NewStaticPrincipal("alice", "test", []string{"watch-all-events-group"}) ctx := auth.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} @@ -373,7 +373,7 @@ func TestEventServer_GetJobSetEvents_Permissions(t *testing.T) { err := s.queueRepository.(armadaqueue.QueueRepository).CreateQueue(ctx, q) assert.NoError(t, err) - principal := auth.NewStaticPrincipal("alice", []string{"watch-events-group", "watch-queue-group"}) + principal := auth.NewStaticPrincipal("alice", "test", []string{"watch-events-group", "watch-queue-group"}) ctx := auth.WithPrincipal(armadacontext.Background(), principal) stream := &eventStreamMock{ctx: ctx} diff --git a/internal/armada/submit/testfixtures/test_fixtures.go b/internal/armada/submit/testfixtures/test_fixtures.go index 4aded552ac5..1c3d8cd592c 100644 --- a/internal/armada/submit/testfixtures/test_fixtures.go +++ b/internal/armada/submit/testfixtures/test_fixtures.go @@ -22,7 +22,7 @@ var ( DefaultOwner = "testUser" DefaultJobset = "testJobset" DefaultQueue = queue.Queue{Name: "testQueue"} - DefaultPrincipal = auth.NewStaticPrincipal(DefaultOwner, []string{"groupA"}) + DefaultPrincipal = auth.NewStaticPrincipal(DefaultOwner, "test", []string{"groupA"}) DefaultContainerPort = v1.ContainerPort{ Name: "testContainerPort", ContainerPort: 8080, diff --git a/internal/binoculars/configuration/types.go b/internal/binoculars/configuration/types.go index c2a3931764b..a10c81f59c4 100644 --- a/internal/binoculars/configuration/types.go +++ b/internal/binoculars/configuration/types.go @@ -3,6 +3,7 @@ package configuration import ( "github.com/armadaproject/armada/internal/common/auth/configuration" grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" ) type BinocularsConfig struct { @@ -12,8 +13,7 @@ type BinocularsConfig struct { GrpcPort uint16 HttpPort uint16 MetricsPort uint16 - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + Profiling *profilingconfig.ProfilingConfig CorsAllowedOrigins []string diff --git a/internal/binoculars/service/cordon_test.go b/internal/binoculars/service/cordon_test.go index 2168fc3771e..1f494a19471 100644 --- a/internal/binoculars/service/cordon_test.go +++ b/internal/binoculars/service/cordon_test.go @@ -38,7 +38,7 @@ var ( ) func TestCordonNode(t *testing.T) { - principal := auth.NewStaticPrincipal("principle", []string{}) + principal := auth.NewStaticPrincipal("principle", "test", []string{}) tests := map[string]struct { additionalLabels map[string]string expectedLabels map[string]string diff --git a/internal/common/auth/anonymous.go b/internal/common/auth/anonymous.go index dfeeb126143..7d25a5f7631 100644 --- a/internal/common/auth/anonymous.go +++ b/internal/common/auth/anonymous.go @@ -2,12 +2,13 @@ package auth import "context" +const AnonymousAuthServiceName = "Anonymous" + type AnonymousAuthService struct{} -func (authService *AnonymousAuthService) Name() string { - return "Anonymous" -} +// Default principal used if no principal can be found in a context. +var anonymousPrincipal = NewStaticPrincipal("anonymous", AnonymousAuthServiceName, []string{}) -func (AnonymousAuthService) Authenticate(ctx context.Context) (Principal, error) { +func (AnonymousAuthService) Authenticate(ctx context.Context, authHeader string) (Principal, error) { return anonymousPrincipal, nil } diff --git a/internal/common/auth/authorization_test.go b/internal/common/auth/authorization_test.go index 002d045a50f..9d7b26faa02 100644 --- a/internal/common/auth/authorization_test.go +++ b/internal/common/auth/authorization_test.go @@ -71,8 +71,8 @@ func TestAuthorizer_AuthorizeQueueAction(t *testing.T) { PriorityFactor: 1, } - authorizedPrincipal := NewStaticPrincipal("alice", []string{"submit-job-group"}) - unauthorizedPrincipcal := NewStaticPrincipal("alice", []string{}) + authorizedPrincipal := NewStaticPrincipal("alice", "test", []string{"submit-job-group"}) + unauthorizedPrincipcal := NewStaticPrincipal("alice", "test", []string{}) tests := map[string]struct { ctx *armadacontext.Context diff --git a/internal/common/auth/basic.go b/internal/common/auth/basic.go index dac30220fdd..3f477dcb3f5 100644 --- a/internal/common/auth/basic.go +++ b/internal/common/auth/basic.go @@ -5,12 +5,12 @@ import ( "encoding/base64" "strings" - grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/configuration" ) +const BasicAuthServiceName = "Basic" + type BasicAuthService struct { users map[string]configuration.UserInfo } @@ -19,35 +19,33 @@ func NewBasicAuthService(users map[string]configuration.UserInfo) *BasicAuthServ return &BasicAuthService{users: users} } -func (authService *BasicAuthService) Name() string { - return "Basic" -} - -func (authService *BasicAuthService) Authenticate(ctx context.Context) (Principal, error) { - basicAuth, err := grpc_auth.AuthFromMD(ctx, "basic") - if err == nil { - payload, err := base64.StdEncoding.DecodeString(basicAuth) - if err != nil { - return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), - Message: err.Error(), - } +func (authService *BasicAuthService) Authenticate(_ context.Context, authHeader string) (Principal, error) { + authHeaderSplits := strings.SplitN(authHeader, " ", 2) + if len(authHeaderSplits) < 2 || !strings.EqualFold(authHeaderSplits[0], "basic") { + return nil, &armadaerrors.ErrMissingCredentials{ + AuthService: BasicAuthServiceName, + Message: "basic auth header not found", } - pair := strings.SplitN(string(payload), ":", 2) - return authService.loginUser(pair[0], pair[1]) } - return nil, &armadaerrors.ErrMissingCredentials{ - AuthService: authService.Name(), + + payload, err := base64.StdEncoding.DecodeString(authHeaderSplits[1]) + if err != nil { + return nil, &armadaerrors.ErrInvalidCredentials{ + AuthService: BasicAuthServiceName, + Message: err.Error(), + } } + pair := strings.SplitN(string(payload), ":", 2) + return authService.loginUser(pair[0], pair[1]) } func (authService *BasicAuthService) loginUser(username string, password string) (Principal, error) { userInfo, ok := authService.users[username] if ok && userInfo.Password == password { - return NewStaticPrincipal(username, userInfo.Groups), nil + return NewStaticPrincipal(username, BasicAuthServiceName, userInfo.Groups), nil } return nil, &armadaerrors.ErrInvalidCredentials{ Username: username, - AuthService: authService.Name(), + AuthService: BasicAuthServiceName, } } diff --git a/internal/common/auth/basic_test.go b/internal/common/auth/basic_test.go index bed9f3a6443..47c81fb269b 100644 --- a/internal/common/auth/basic_test.go +++ b/internal/common/auth/basic_test.go @@ -17,20 +17,22 @@ func TestBasicAuthService(t *testing.T) { "root": {"toor", []string{}}, }) + auth1 := basicPassword("root", "toor") principal, e := service.Authenticate( - metadata.NewIncomingContext(context.Background(), basicPassword("root", "toor"))) + metadata.NewIncomingContext(context.Background(), auth1), auth1["authorization"][0]) assert.Nil(t, e) assert.Equal(t, principal.GetName(), "root") + auth2 := basicPassword("root", "test") _, e = service.Authenticate( - metadata.NewIncomingContext(context.Background(), basicPassword("root", "test"))) + metadata.NewIncomingContext(context.Background(), auth2), auth2["authorization"][0]) assert.NotNil(t, e) var invalidCredsErr *armadaerrors.ErrInvalidCredentials assert.ErrorAs(t, e, &invalidCredsErr) - _, e = service.Authenticate(context.Background()) + _, e = service.Authenticate(context.Background(), "") var missingCredsErr *armadaerrors.ErrMissingCredentials assert.ErrorAs(t, e, &missingCredsErr) } diff --git a/internal/common/auth/common.go b/internal/common/auth/common.go index 97fd664a6fe..f3bd468cffd 100644 --- a/internal/common/auth/common.go +++ b/internal/common/auth/common.go @@ -2,13 +2,13 @@ package auth import ( "context" - "errors" + "net/http" - grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "golang.org/x/exp/slices" - "github.com/armadaproject/armada/internal/common/armadaerrors" + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" + "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" + "github.com/armadaproject/armada/internal/common/util" ) @@ -18,14 +18,12 @@ const principalKey = "principal" // All users are implicitly part of this group. const EveryoneGroup = "everyone" -// Default principal used if no principal can be found in a context. -var anonymousPrincipal = NewStaticPrincipal("anonymous", []string{}) - // Principal represents an entity that can be authenticated (e.g., a user). // Each principal has a name associated with it and may be part of one or more groups. // Scopes and claims are as defined in OpenId. type Principal interface { GetName() string + GetAuthMethod() string GetGroupNames() []string IsInGroup(group string) bool HasScope(scope string) bool @@ -35,24 +33,27 @@ type Principal interface { // Default implementation of the Principal interface. // Here, static refers to the fact that the principal doesn't change once it has been created. type StaticPrincipal struct { - name string - groups map[string]bool - scopes map[string]bool - claims map[string]bool + name string + authMethod string + groups map[string]bool + scopes map[string]bool + claims map[string]bool } -func NewStaticPrincipal(name string, groups []string) *StaticPrincipal { +func NewStaticPrincipal(name string, authMethod string, groups []string) *StaticPrincipal { return &StaticPrincipal{ name, + authMethod, util.StringListToSet(append(groups, EveryoneGroup)), map[string]bool{}, map[string]bool{}, } } -func NewStaticPrincipalWithScopesAndClaims(name string, groups []string, scopes []string, claims []string) *StaticPrincipal { +func NewStaticPrincipalWithScopesAndClaims(name string, authMethod string, groups []string, scopes []string, claims []string) *StaticPrincipal { return &StaticPrincipal{ name, + authMethod, util.StringListToSet(append(groups, EveryoneGroup)), util.StringListToSet(scopes), util.StringListToSet(claims), @@ -75,6 +76,10 @@ func (p *StaticPrincipal) GetName() string { return p.name } +func (p *StaticPrincipal) GetAuthMethod() string { + return p.authMethod +} + func (p *StaticPrincipal) GetGroupNames() []string { names := []string{} for g := range p.groups { @@ -100,42 +105,56 @@ func WithPrincipal(ctx context.Context, principal Principal) context.Context { return context.WithValue(ctx, principalKey, principal) } -// AuthService represents a method of authentication for the gRPC API. +// AuthService represents a method of authentication for the HTTP or gRPC API. // Each implementation represents a particular method, e.g., username/password or OpenID. -// The gRPC server may be started with multiple AuthService to give several options for authentication. +// The HTTP/gRPC server may be started with multiple AuthService to give several options for authentication. type AuthService interface { - Authenticate(ctx context.Context) (Principal, error) - Name() string + Authenticate(ctx context.Context, authHeader string) (Principal, error) } -// CreateMiddlewareAuthFunction returns an authentication function that combines the given -// authentication services. That function returns success if any service successfully +// CreateGrpcMiddlewareAuthFunction for use with GRPC. +// That function returns success if any service successfully // authenticates the user, and an error if all services fail to authenticate. -// The services in authServices are tried one at a time in sequence. -// Successful authentication short-circuits the process. // // If authentication succeeds, the username returned by the authentication service is added to the // request context for logging purposes. -func CreateMiddlewareAuthFunction(authServices []AuthService) grpc_auth.AuthFunc { +func CreateGrpcMiddlewareAuthFunction(authService AuthService) func(ctx context.Context) (context.Context, error) { return func(ctx context.Context) (context.Context, error) { - for _, service := range authServices { - principal, err := service.Authenticate(ctx) - - var missingCredsErr *armadaerrors.ErrMissingCredentials - if errors.As(err, &missingCredsErr) { - // try next auth service - continue - } else if err != nil { - return nil, err - } - - // record user name for request logging - grpc_ctxtags.Extract(ctx).Set("user", principal.GetName()) - grpc_ctxtags.Extract(ctx).Set("authService", service.Name()) - return WithPrincipal(ctx, principal), nil + authHeader := metautils.ExtractIncoming(ctx).Get("authorization") + principal, err := authService.Authenticate(ctx, authHeader) + if err != nil { + return nil, err } - return nil, &armadaerrors.ErrUnauthenticated{ - Message: "Request could not be authenticated with any of the supported schemes.", + + // record username for request logging + grpc_ctxtags.Extract(ctx).Set("user", principal.GetName()) + grpc_ctxtags.Extract(ctx).Set("authService", principal.GetAuthMethod()) + + return WithPrincipal(ctx, principal), nil + } +} + +// CreateHttpMiddlewareAuthFunction for use with GRPC. +// That function returns success if any service successfully +// authenticates the user, and an error if all services fail to authenticate. +// +// If authentication succeeds, the username returned by the authentication service is added to the +// request context for logging purposes. +func CreateHttpMiddlewareAuthFunction(authService AuthService) func(w http.ResponseWriter, r *http.Request) (context.Context, error) { + return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + + authHeader := r.Header.Get("Authorization") + principal, err := authService.Authenticate(ctx, authHeader) + if err != nil { + http.Error(w, "auth error:"+err.Error(), http.StatusInternalServerError) + return nil, err } + + // record username for request logging + grpc_ctxtags.Extract(ctx).Set("user", principal.GetName()) + grpc_ctxtags.Extract(ctx).Set("authService", principal.GetAuthMethod()) + + return WithPrincipal(ctx, principal), nil } } diff --git a/internal/common/auth/common_test.go b/internal/common/auth/common_test.go index 43edf1e9a8f..05eecab476e 100644 --- a/internal/common/auth/common_test.go +++ b/internal/common/auth/common_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "errors" + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -10,21 +11,57 @@ import ( "github.com/armadaproject/armada/internal/common/armadaerrors" ) -func TestCreateMiddlewareAuthFunction(t *testing.T) { - principal := NewStaticPrincipal("test", []string{"group"}) - failingService := &fakeAuthService{nil, errors.New("failed")} - serviceWithoutCredentials := &fakeAuthService{nil, &armadaerrors.ErrMissingCredentials{}} +func TestCreateGrpcMiddlewareAuthFunction_Success(t *testing.T) { + principal := NewStaticPrincipal("test", "test", []string{"group"}) successfulService := &fakeAuthService{principal, nil} - _, e := CreateMiddlewareAuthFunction([]AuthService{failingService, successfulService})(context.Background()) + c, e := CreateGrpcMiddlewareAuthFunction(successfulService)(context.Background()) + assert.Nil(t, e) + ctxPrincipal := GetPrincipal(c) + assert.Equal(t, principal, ctxPrincipal, "principal should be added to context") +} + +func TestCreateGrpcMiddlewareAuthFunction_Failure(t *testing.T) { + failingService := &fakeAuthService{nil, errors.New("failed")} + + _, e := CreateGrpcMiddlewareAuthFunction(failingService)(context.Background()) assert.NotNil(t, e, "failed auth should result in error") +} + +func TestCreateGrpcMiddlewareAuthFunction_MissingCredentials(t *testing.T) { + serviceWithoutCredentials := &fakeAuthService{nil, &armadaerrors.ErrMissingCredentials{}} - c, e := CreateMiddlewareAuthFunction([]AuthService{serviceWithoutCredentials, successfulService})(context.Background()) + _, e := CreateGrpcMiddlewareAuthFunction(serviceWithoutCredentials)(context.Background()) + assert.NotNil(t, e, "no credentials should result in error") +} + +func TestCreateHttpMiddlewareAuthFunction_Success(t *testing.T) { + principal := NewStaticPrincipal("test", "test", []string{"group"}) + successfulService := &fakeAuthService{principal, nil} + + w := newFakeResponseWriter() + var r http.Request + c, e := CreateHttpMiddlewareAuthFunction(successfulService)(w, &r) assert.Nil(t, e) ctxPrincipal := GetPrincipal(c) assert.Equal(t, principal, ctxPrincipal, "principal should be added to context") +} - _, e = CreateMiddlewareAuthFunction([]AuthService{serviceWithoutCredentials})(context.Background()) +func TestCreateHttpMiddlewareAuthFunction_Failure(t *testing.T) { + failingService := &fakeAuthService{nil, errors.New("failed")} + + w := newFakeResponseWriter() + var r http.Request + _, e := CreateHttpMiddlewareAuthFunction(failingService)(w, &r) + assert.NotNil(t, e, "failed auth should result in error") +} + +func TestCreateHttpMiddlewareAuthFunction_MissingCredentials(t *testing.T) { + serviceWithoutCredentials := &fakeAuthService{nil, &armadaerrors.ErrMissingCredentials{}} + + w := newFakeResponseWriter() + var r http.Request + _, e := CreateHttpMiddlewareAuthFunction(serviceWithoutCredentials)(w, &r) assert.NotNil(t, e, "no credentials should result in error") } @@ -33,10 +70,27 @@ type fakeAuthService struct { err error } -func (f *fakeAuthService) Name() string { - return "Fake" +func (f *fakeAuthService) Authenticate(ctx context.Context, authHeader string) (Principal, error) { + return f.principal, f.err } -func (f *fakeAuthService) Authenticate(ctx context.Context) (Principal, error) { - return f.principal, f.err +type fakeResponseWriter struct { + header map[string][]string +} + +func newFakeResponseWriter() http.ResponseWriter { + return fakeResponseWriter{ + header: map[string][]string{}, + } +} + +func (w fakeResponseWriter) Header() http.Header { + return w.header +} + +func (w fakeResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (w fakeResponseWriter) WriteHeader(statusCode int) { } diff --git a/internal/common/auth/kubernetes.go b/internal/common/auth/kubernetes.go index c81c9da9779..f07a45582fa 100644 --- a/internal/common/auth/kubernetes.go +++ b/internal/common/auth/kubernetes.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/patrickmn/go-cache" authv1 "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -21,6 +20,8 @@ import ( "github.com/armadaproject/armada/internal/common/auth/configuration" ) +const KubernetesAuthServiceName = "KubernetesNative" + type TokenReviewer interface { ReviewToken(ctx context.Context, clusterUrl string, token string, ca []byte) (*authv1.TokenReview, error) } @@ -71,24 +72,18 @@ type CacheData struct { Valid bool `json:"valid"` } -func (authService *KubernetesNativeAuthService) Name() string { - return "KubernetesNative" -} - -func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context) (Principal, error) { - // Retrieve token from context. - authHeader := strings.SplitN(metautils.ExtractIncoming(ctx).Get("authorization"), " ", 2) - - if len(authHeader) < 2 || authHeader[0] != "KubernetesAuth" { +func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context, authHeader string) (Principal, error) { + authHeaderSplits := strings.SplitN(authHeader, " ", 2) + if len(authHeaderSplits) < 2 || authHeaderSplits[0] != "KubernetesAuth" { return nil, &armadaerrors.ErrMissingCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, } } - token, ca, err := parseAuth(authHeader[1]) + token, ca, err := parseAuth(authHeaderSplits[1]) if err != nil { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, } } @@ -96,14 +91,14 @@ func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context expirationTime, err := parseTime(token) if err != nil { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: err.Error(), } } if authService.Clock.Now().After(expirationTime) { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: "invalid token, expired", } } @@ -113,10 +108,10 @@ func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context if found { if cacheInfo, ok := data.(CacheData); ok { if cacheInfo.Valid { - return NewStaticPrincipal(cacheInfo.Name, []string{cacheInfo.Name}), nil + return NewStaticPrincipal(cacheInfo.Name, KubernetesAuthServiceName, []string{cacheInfo.Name}), nil } else { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: "token invalid", } } @@ -127,7 +122,7 @@ func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context url, err := authService.getClusterURL(token) if err != nil { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: err.Error(), } } @@ -136,7 +131,10 @@ func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context name, err := authService.reviewToken(ctx, url, token, []byte(ca)) if err != nil { // reviewToken returns appropriate armadaerrors. - return nil, err + return nil, &armadaerrors.ErrInvalidCredentials{ + AuthService: KubernetesAuthServiceName, + Message: err.Error(), + } } // Add to cache @@ -149,7 +147,7 @@ func (authService *KubernetesNativeAuthService) Authenticate(ctx context.Context expirationTime.Sub(time.Now())) // Return very basic Principal - return NewStaticPrincipal(name, []string{name}), nil + return NewStaticPrincipal(name, KubernetesAuthServiceName, []string{name}), nil } func (authService *KubernetesNativeAuthService) getClusterURL(token string) (string, error) { @@ -185,7 +183,7 @@ func (authService *KubernetesNativeAuthService) reviewToken(ctx context.Context, // TODO(clif) Hard to tell if this should be internal auth error // or invalid creds still. return "", &armadaerrors.ErrInternalAuthServiceError{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: err.Error(), } } @@ -193,7 +191,7 @@ func (authService *KubernetesNativeAuthService) reviewToken(ctx context.Context, if !result.Status.Authenticated { authService.TokenCache.Set(token, CacheData{Valid: false}, time.Duration(authService.InvalidTokenExpiry)) return "", &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: KubernetesAuthServiceName, Message: "provided token was rejected by TokenReview", } } diff --git a/internal/common/auth/kubernetes_test.go b/internal/common/auth/kubernetes_test.go index d178a8390a3..01da9d02b1b 100644 --- a/internal/common/auth/kubernetes_test.go +++ b/internal/common/auth/kubernetes_test.go @@ -157,9 +157,9 @@ func TestAuthenticateKubernetes(t *testing.T) { // Authenticate authService := createTestAuthService(tempdir+"/", true, testName, testTokenIss) - principal, err := authService.Authenticate(ctx) + principal, err := authService.Authenticate(ctx, payload) - expected := NewStaticPrincipal(testName, []string{testName}) + expected := NewStaticPrincipal(testName, KubernetesAuthServiceName, []string{testName}) assert.NoError(t, err) assert.Equal(t, expected, principal) } diff --git a/internal/common/auth/multi.go b/internal/common/auth/multi.go new file mode 100644 index 00000000000..c74b245e94b --- /dev/null +++ b/internal/common/auth/multi.go @@ -0,0 +1,36 @@ +package auth + +import ( + "context" + "errors" + + "github.com/armadaproject/armada/internal/common/armadaerrors" +) + +type MultiAuthService struct { + authServices []AuthService +} + +func NewMultiAuthService(authServices []AuthService) *MultiAuthService { + return &MultiAuthService{authServices: authServices} +} + +// Authenticate - The services in authServices are tried one at a time in sequence. +// Successful authentication short-circuits the process. +func (multi *MultiAuthService) Authenticate(ctx context.Context, authHeader string) (Principal, error) { + for _, service := range multi.authServices { + principal, err := service.Authenticate(ctx, authHeader) + + var missingCredsErr *armadaerrors.ErrMissingCredentials + if errors.As(err, &missingCredsErr) { + // try next auth service + continue + } else if err != nil { + return nil, err + } + return principal, nil + } + return nil, &armadaerrors.ErrUnauthenticated{ + Message: "Request could not be authenticated with any of the supported schemes.", + } +} diff --git a/internal/common/auth/multi_test.go b/internal/common/auth/multi_test.go new file mode 100644 index 00000000000..531c4cd79a2 --- /dev/null +++ b/internal/common/auth/multi_test.go @@ -0,0 +1,31 @@ +package auth + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/armadaproject/armada/internal/common/armadaerrors" +) + +func TestNewMultiAuthService(t *testing.T) { + principal := NewStaticPrincipal("test", "test", []string{"group"}) + failingService := &fakeAuthService{nil, errors.New("failed")} + serviceWithoutCredentials := &fakeAuthService{nil, &armadaerrors.ErrMissingCredentials{}} + successfulService := &fakeAuthService{principal, nil} + + sut := NewMultiAuthService([]AuthService{failingService, successfulService}) + _, e := sut.Authenticate(context.Background(), "") + assert.NotNil(t, e, "failed auth should result in error") + + sut = NewMultiAuthService([]AuthService{serviceWithoutCredentials, successfulService}) + p, e := sut.Authenticate(context.Background(), "") + assert.Nil(t, e) + assert.Equal(t, principal.GetName(), p.GetName(), "principal should be returned") + + sut = NewMultiAuthService([]AuthService{serviceWithoutCredentials}) + p, e = sut.Authenticate(context.Background(), "") + assert.NotNil(t, e, "no credentials should result in error") +} diff --git a/internal/common/auth/oidc.go b/internal/common/auth/oidc.go index b0219f8732e..5062b217cfa 100644 --- a/internal/common/auth/oidc.go +++ b/internal/common/auth/oidc.go @@ -6,13 +6,14 @@ import ( "strings" "github.com/coreos/go-oidc" - grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/configuration" "github.com/armadaproject/armada/internal/common/auth/permission" ) +const OidcAuthServiceName = "OIDC" + type PermissionClaimQueries map[permission.Permission]string type OpenIdAuthService struct { @@ -37,23 +38,18 @@ func NewOpenIdAuthService(verifier *oidc.IDTokenVerifier, groupsClaim string) *O return &OpenIdAuthService{verifier, groupsClaim} } -func (authService *OpenIdAuthService) Name() string { - return "OIDC" -} - -func (authService *OpenIdAuthService) Authenticate(ctx context.Context) (Principal, error) { - token, err := grpc_auth.AuthFromMD(ctx, "bearer") - if err != nil { +func (authService *OpenIdAuthService) Authenticate(ctx context.Context, authHeader string) (Principal, error) { + authHeaderSplits := strings.SplitN(authHeader, " ", 2) + if len(authHeaderSplits) < 2 || !strings.EqualFold(authHeaderSplits[0], "bearer") { return nil, &armadaerrors.ErrMissingCredentials{ - AuthService: authService.Name(), - Message: err.Error(), + AuthService: OidcAuthServiceName, } } - verifiedToken, err := authService.verifier.Verify(ctx, token) + verifiedToken, err := authService.verifier.Verify(ctx, authHeaderSplits[1]) if err != nil { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: OidcAuthServiceName, Message: err.Error(), } } @@ -61,13 +57,14 @@ func (authService *OpenIdAuthService) Authenticate(ctx context.Context) (Princip rawClaims, err := extractRawClaims(verifiedToken) if err != nil { return nil, &armadaerrors.ErrInvalidCredentials{ - AuthService: authService.Name(), + AuthService: OidcAuthServiceName, Message: err.Error(), } } return NewStaticPrincipalWithScopesAndClaims( verifiedToken.Subject, + OidcAuthServiceName, authService.extractGroups(rawClaims), authService.extractScopes(verifiedToken), authService.extractClaims(rawClaims)), nil diff --git a/internal/common/auth/oidc_test.go b/internal/common/auth/oidc_test.go index f79c4d0a1eb..9152409763f 100644 --- a/internal/common/auth/oidc_test.go +++ b/internal/common/auth/oidc_test.go @@ -29,23 +29,24 @@ func TestOpenIdAuthService(t *testing.T) { keySet := &fakeKeySet{payload, nil} verifier := oidc.NewVerifier("fake_issuer", keySet, &oidc.Config{SkipClientIDCheck: true}) + authHeader := "bearer " + token ctx := metadata.NewIncomingContext(context.Background(), map[string][]string{ - "authorization": {"bearer " + token}, + "authorization": {authHeader}, }) service := NewOpenIdAuthService(verifier, "groups") - principal, e := service.Authenticate(ctx) + principal, e := service.Authenticate(ctx, authHeader) assert.Nil(t, e) assert.Equal(t, "me", principal.GetName()) assert.True(t, principal.IsInGroup("test")) - _, e = service.Authenticate(context.Background()) + _, e = service.Authenticate(context.Background(), "") var missingCredsErr *armadaerrors.ErrMissingCredentials assert.ErrorAs(t, e, &missingCredsErr) keySet.err = errors.New("wrong signature") - _, e = service.Authenticate(ctx) + _, e = service.Authenticate(ctx, authHeader) assert.NotNil(t, e) assert.NotErrorIs(t, e, missingCredsErr) } diff --git a/internal/common/auth/permissions_test.go b/internal/common/auth/permissions_test.go index 9590c5e0896..94eec3e8fb6 100644 --- a/internal/common/auth/permissions_test.go +++ b/internal/common/auth/permissions_test.go @@ -68,16 +68,16 @@ func init() { map[permission.Permission][]string{executeJobsPermission: {executorClaim}}, ) - admin = NewStaticPrincipal("admin", []string{adminGroup}) - submitter = NewStaticPrincipal("submitter", []string{submitterGroup}) - otherUser = NewStaticPrincipal("otherUser", []string{unimportantGroup}) - userWithCreatorScope = NewStaticPrincipalWithScopesAndClaims("creatorScopeUser", []string{unimportantGroup}, []string{creatorScope}, []string{}) - userWithExecutorClaim = NewStaticPrincipalWithScopesAndClaims("executorClaimUser", []string{unimportantGroup}, []string{}, []string{executorClaim}) - - thingOwnerDirect = NewStaticPrincipal("thingOwnerDirect", []string{}) - thingOwnerDirectAndViaGroup = NewStaticPrincipal("thingOwnerDirectAndViaGroup", []string{thingOwningGroup}) - thingOwnerViaGroup = NewStaticPrincipal("thingOwnerViaGroup", []string{thingOwningGroup}) - thingNonOwner = NewStaticPrincipal("thingNonOwner", []string{unimportantGroup}) + admin = NewStaticPrincipal("admin", "test", []string{adminGroup}) + submitter = NewStaticPrincipal("submitter", "test", []string{submitterGroup}) + otherUser = NewStaticPrincipal("otherUser", "test", []string{unimportantGroup}) + userWithCreatorScope = NewStaticPrincipalWithScopesAndClaims("creatorScopeUser", "test", []string{unimportantGroup}, []string{creatorScope}, []string{}) + userWithExecutorClaim = NewStaticPrincipalWithScopesAndClaims("executorClaimUser", "test", []string{unimportantGroup}, []string{}, []string{executorClaim}) + + thingOwnerDirect = NewStaticPrincipal("thingOwnerDirect", "test", []string{}) + thingOwnerDirectAndViaGroup = NewStaticPrincipal("thingOwnerDirectAndViaGroup", "test", []string{thingOwningGroup}) + thingOwnerViaGroup = NewStaticPrincipal("thingOwnerViaGroup", "test", []string{thingOwningGroup}) + thingNonOwner = NewStaticPrincipal("thingNonOwner", "test", []string{unimportantGroup}) ownedThing = &OwnedThing{ []string{thingOwnerDirect.GetName(), thingOwnerDirectAndViaGroup.GetName()}, []string{thingOwningGroup}, diff --git a/internal/common/grpc/grpc.go b/internal/common/grpc/grpc.go index 045c448fc15..c4f858f3ba1 100644 --- a/internal/common/grpc/grpc.go +++ b/internal/common/grpc/grpc.go @@ -73,7 +73,7 @@ func CreateGrpcServer( // Authentication // The provided authServices represents a list of services that can be used to authenticate // the client (e.g., username/password and OpenId). authFunction is a combination of these. - authFunction := auth.CreateMiddlewareAuthFunction(authServices) + authFunction := auth.CreateGrpcMiddlewareAuthFunction(auth.NewMultiAuthService(authServices)) unaryInterceptors = append(unaryInterceptors, grpc_auth.UnaryServerInterceptor(authFunction)) streamInterceptors = append(streamInterceptors, grpc_auth.StreamServerInterceptor(authFunction)) diff --git a/internal/common/profiling/configuration/types.go b/internal/common/profiling/configuration/types.go new file mode 100644 index 00000000000..80194c834a8 --- /dev/null +++ b/internal/common/profiling/configuration/types.go @@ -0,0 +1,10 @@ +package configuration + +import ( + authconfig "github.com/armadaproject/armada/internal/common/auth/configuration" +) + +type ProfilingConfig struct { + Port uint16 + Auth *authconfig.AuthConfig +} diff --git a/internal/common/profiling/http.go b/internal/common/profiling/http.go index 69337153c7d..cea7021bc13 100644 --- a/internal/common/profiling/http.go +++ b/internal/common/profiling/http.go @@ -1,11 +1,67 @@ package profiling import ( + "context" "fmt" "net/http" _ "net/http/pprof" + + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" + + "github.com/armadaproject/armada/internal/common/armadacontext" + "github.com/armadaproject/armada/internal/common/auth" + "github.com/armadaproject/armada/internal/common/logging" + "github.com/armadaproject/armada/internal/common/profiling/configuration" + "github.com/armadaproject/armada/internal/common/serve" ) +func SetupPprof(config *configuration.ProfilingConfig, ctx *armadacontext.Context, g *errgroup.Group) error { + if config == nil { + log.Infof("Pprof server not configured, skipping") + return nil + } + + log.Infof("Setting up pprof server on port %d", config.Port) + + authServices, err := auth.ConfigureAuth(*config.Auth) + if err != nil { + return fmt.Errorf("error configuring pprof auth :%v", err) + } + + authenticationFunc := auth.CreateHttpMiddlewareAuthFunction(auth.NewMultiAuthService(authServices)) + + authorizer := auth.NewAuthorizer( + auth.NewPrincipalPermissionChecker( + config.Auth.PermissionGroupMapping, + config.Auth.PermissionScopeMapping, + config.Auth.PermissionClaimMapping, + ), + ) + + pprofServer := setupPprofHttpServerWithAuth(config.Port, func(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx, err := authenticationFunc(w, r) + if err != nil { + return nil, err + } + + return ctx, authorizer.AuthorizeAction(armadacontext.FromGrpcCtx(ctx), "pprof") + }) + + if g != nil { + g.Go(func() error { + return serve.ListenAndServe(ctx, pprofServer) + }) + } else { + go func() { + if err := serve.ListenAndServe(ctx, pprofServer); err != nil { + logging.WithStacktrace(ctx, err).Error("pprof server failure") + } + }() + } + return nil +} + // SetupPprofHttpServer does two things: // // 1. Because importing "net/http/pprof" automatically binds profiling endpoints to http.DefaultServeMux, @@ -13,11 +69,33 @@ import ( // are exposed on a separate mux available only from localhost.Hence, this function should be called // before adding any other endpoints to http.DefaultServeMux. // 2. Returns a http server serving net/http/pprof endpoints on localhost:port. -func SetupPprofHttpServer(port uint16) *http.Server { +func setupPprofHttpServerWithAuth(port uint16, authFunc func(w http.ResponseWriter, r *http.Request) (context.Context, error)) *http.Server { pprofMux := http.DefaultServeMux http.DefaultServeMux = http.NewServeMux() + + authInterceptor := AuthInterceptor{ + underlying: pprofMux, + authFunc: authFunc, + } + return &http.Server{ Addr: fmt.Sprintf("localhost:%d", port), - Handler: pprofMux, + Handler: authInterceptor, + } +} + +type AuthInterceptor struct { + underlying http.Handler + authFunc func(w http.ResponseWriter, r *http.Request) (context.Context, error) +} + +func (i AuthInterceptor) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, err := i.authFunc(w, r) + if err != nil { + log.Errorf("Pprof auth failed: %v", err) + return } + principal := auth.GetPrincipal(ctx) + log.Infof("Pprof auth succeeded (method %s, principal %s)", principal.GetAuthMethod(), principal.GetName()) + i.underlying.ServeHTTP(w, r) } diff --git a/internal/eventingester/configuration/types.go b/internal/eventingester/configuration/types.go index 65f67a2ad32..7a660838b63 100644 --- a/internal/eventingester/configuration/types.go +++ b/internal/eventingester/configuration/types.go @@ -6,6 +6,7 @@ import ( "github.com/redis/go-redis/v9" "github.com/armadaproject/armada/internal/armada/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" ) type EventIngesterConfiguration struct { @@ -29,8 +30,8 @@ type EventIngesterConfiguration struct { EventRetentionPolicy EventRetentionPolicy // List of Regexes which will identify fatal errors when inserting into redis FatalInsertionErrors []string - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + // If non-nil, configures pprof profiling + Profiling *profilingconfig.ProfilingConfig } // TODO: unpack this into just EventExpirtation diff --git a/internal/eventingester/ingester.go b/internal/eventingester/ingester.go index e47a7eb1afa..db5a402a7c6 100644 --- a/internal/eventingester/ingester.go +++ b/internal/eventingester/ingester.go @@ -13,9 +13,7 @@ import ( "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/ingest" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/eventingester/configuration" "github.com/armadaproject/armada/internal/eventingester/convert" "github.com/armadaproject/armada/internal/eventingester/metrics" @@ -28,14 +26,9 @@ func Run(config *configuration.EventIngesterConfiguration) { log.Info("Event Ingester Starting") // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } metrics := metrics.Get() diff --git a/internal/executor/configuration/types.go b/internal/executor/configuration/types.go index 46fadfb46c4..96bd3febf81 100644 --- a/internal/executor/configuration/types.go +++ b/internal/executor/configuration/types.go @@ -5,6 +5,7 @@ import ( "google.golang.org/grpc/keepalive" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" armadaresource "github.com/armadaproject/armada/internal/common/resource" "github.com/armadaproject/armada/internal/executor/configuration/podchecks" "github.com/armadaproject/armada/pkg/client" @@ -160,7 +161,7 @@ const ( type ExecutorConfiguration struct { HttpPort uint16 // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + Profiling *profilingconfig.ProfilingConfig Metric MetricConfiguration Application ApplicationConfiguration ExecutorApiConnection client.ApiConnectionDetails diff --git a/internal/lookoutingesterv2/configuration/types.go b/internal/lookoutingesterv2/configuration/types.go index b6b6b154b7f..0efbbc9a165 100644 --- a/internal/lookoutingesterv2/configuration/types.go +++ b/internal/lookoutingesterv2/configuration/types.go @@ -4,6 +4,7 @@ import ( "time" "github.com/armadaproject/armada/internal/armada/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" ) type LookoutIngesterV2Configuration struct { @@ -31,8 +32,8 @@ type LookoutIngesterV2Configuration struct { // Between each attempt to store data in the database, there is an exponential backoff (starting out as 1s). // MaxBackoff caps this backoff to whatever it is specified (in seconds) MaxBackoff int - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + // If non-nil, configures pprof profiling + Profiling *profilingconfig.ProfilingConfig // List of Regexes which will identify fatal errors when inserting into postgres FatalInsertionErrors []string } diff --git a/internal/lookoutingesterv2/ingester.go b/internal/lookoutingesterv2/ingester.go index 8e133be859c..2df02f2301c 100644 --- a/internal/lookoutingesterv2/ingester.go +++ b/internal/lookoutingesterv2/ingester.go @@ -12,9 +12,7 @@ import ( "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" - "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/lookoutingesterv2/configuration" "github.com/armadaproject/armada/internal/lookoutingesterv2/instructions" "github.com/armadaproject/armada/internal/lookoutingesterv2/lookoutdb" @@ -50,14 +48,9 @@ func Run(config *configuration.LookoutIngesterV2Configuration) { } // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err = profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } converter := instructions.NewInstructionConverter(m.Metrics, config.UserAnnotationPrefix, compressor) diff --git a/internal/lookoutv2/configuration/types.go b/internal/lookoutv2/configuration/types.go index 1fb0be21595..a4bb462df39 100644 --- a/internal/lookoutv2/configuration/types.go +++ b/internal/lookoutv2/configuration/types.go @@ -4,13 +4,13 @@ import ( "time" "github.com/armadaproject/armada/internal/armada/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" ) type LookoutV2Config struct { ApiPort int + Profiling *profilingconfig.ProfilingConfig MetricsPort int - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 CorsAllowedOrigins []string Tls TlsConfig diff --git a/internal/scheduler/configuration/configuration.go b/internal/scheduler/configuration/configuration.go index 6ba418e6653..f78bd3fe3e7 100644 --- a/internal/scheduler/configuration/configuration.go +++ b/internal/scheduler/configuration/configuration.go @@ -11,6 +11,7 @@ import ( "github.com/armadaproject/armada/internal/armada/configuration" authconfig "github.com/armadaproject/armada/internal/common/auth/configuration" grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/pkg/client" ) @@ -40,8 +41,8 @@ type Configuration struct { Auth authconfig.AuthConfig Grpc grpcconfig.GrpcConfig Http HttpConfig - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + // If non-nil, configures pprof profiling + Profiling *profilingconfig.ProfilingConfig // Maximum number of strings that should be cached at any one time InternedStringsCacheSize uint32 `validate:"required"` // How often the scheduling cycle should run diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 5ee162c7488..6d6cfcc11bd 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -29,7 +29,6 @@ import ( "github.com/armadaproject/armada/internal/common/logging" "github.com/armadaproject/armada/internal/common/profiling" "github.com/armadaproject/armada/internal/common/pulsarutils" - "github.com/armadaproject/armada/internal/common/serve" "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/common/types" @@ -53,13 +52,11 @@ func Run(config schedulerconfig.Configuration) error { g, ctx := armadacontext.ErrGroup(app.CreateContextWithShutdown()) // //////////////////////////////////////////////////////////////////////// - // Profiling + // Expose profiling endpoints if enabled. // //////////////////////////////////////////////////////////////////////// - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - g.Go(func() error { - return serve.ListenAndServe(ctx, pprofServer) - }) + err := profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } // //////////////////////////////////////////////////////////////////////// diff --git a/internal/scheduleringester/config.go b/internal/scheduleringester/config.go index e08b5f64e1e..2e1f945afbc 100644 --- a/internal/scheduleringester/config.go +++ b/internal/scheduleringester/config.go @@ -4,6 +4,7 @@ import ( "time" "github.com/armadaproject/armada/internal/armada/configuration" + profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration" "github.com/armadaproject/armada/internal/common/types" ) @@ -26,6 +27,6 @@ type Configuration struct { PulsarReceiveTimeout time.Duration // Time for which the pulsar consumer will back off after receiving an error on trying to receive a message PulsarBackoffTime time.Duration - // If non-nil, net/http/pprof endpoints are exposed on localhost on this port. - PprofPort *uint16 + // If non-nil, configures pprof profiling + Profiling *profilingconfig.ProfilingConfig } diff --git a/internal/scheduleringester/ingester.go b/internal/scheduleringester/ingester.go index 0eb3e268643..cc4ae584716 100644 --- a/internal/scheduleringester/ingester.go +++ b/internal/scheduleringester/ingester.go @@ -3,19 +3,16 @@ package scheduleringester import ( "time" - "github.com/armadaproject/armada/internal/common/armadacontext" - "github.com/armadaproject/armada/internal/common/logging" - "github.com/armadaproject/armada/internal/common/profiling" - "github.com/armadaproject/armada/internal/common/serve" - "github.com/apache/pulsar-client-go/pulsar" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common/app" + "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/database" "github.com/armadaproject/armada/internal/common/ingest" "github.com/armadaproject/armada/internal/common/ingest/metrics" + "github.com/armadaproject/armada/internal/common/profiling" ) // Run will create a pipeline that will take Armada event messages from Pulsar and update the schedulerDb. @@ -36,14 +33,9 @@ func Run(config Configuration) error { } // Expose profiling endpoints if enabled. - if config.PprofPort != nil { - pprofServer := profiling.SetupPprofHttpServer(*config.PprofPort) - go func() { - ctx := armadacontext.Background() - if err := serve.ListenAndServe(ctx, pprofServer); err != nil { - logging.WithStacktrace(ctx, err).Error("pprof server failure") - } - }() + err = profiling.SetupPprof(config.Profiling, armadacontext.Background(), nil) + if err != nil { + log.Fatalf("Pprof setup failed, exiting, %v", err) } ingester := ingest.NewIngestionPipeline[*DbOperationsWithMessageIds](