From 2944cf65d054a384568bae1e2f98c63e06654f80 Mon Sep 17 00:00:00 2001 From: Vasily Tsybenko Date: Tue, 10 Dec 2024 16:43:45 +0200 Subject: [PATCH] Collect Prometheus metrics with token introspection result status --- idptoken/grpc_client.go | 2 +- idptoken/introspector.go | 4 +--- idptoken/provider.go | 2 +- idptoken/provider_test.go | 4 ++-- internal/metrics/metrics.go | 41 +++++++++++++++++++++++++++++++++++-- jwks/client.go | 2 +- jwt/caching_parser.go | 2 +- middleware.go | 22 +++++++++++++++++--- middleware_test.go | 32 ++++++++++++++++++++++++++++- 9 files changed, 96 insertions(+), 15 deletions(-) diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go index 102ad97..d68263b 100644 --- a/idptoken/grpc_client.go +++ b/idptoken/grpc_client.go @@ -84,7 +84,7 @@ func NewGRPCClientWithOpts( client: pb.NewIDPTokenServiceClient(conn), clientConn: conn, reqTimeout: opts.RequestTimeout, - promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "grpc_client"), + promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceGRPCClient), }, nil } diff --git a/idptoken/introspector.go b/idptoken/introspector.go index a92ce48..7f77113 100644 --- a/idptoken/introspector.go +++ b/idptoken/introspector.go @@ -32,8 +32,6 @@ import ( const minAccessTokenProviderInvalidationInterval = time.Minute -const tokenIntrospectorPromSource = "token_introspector" - const ( // DefaultIntrospectionClaimsCacheMaxEntries is a default maximum number of entries in the claims cache. // Claims cache is used for storing introspected active tokens. @@ -250,7 +248,7 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt } scopeFilterFormURLEncoded := values.Encode() - promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource) + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenIntrospector) claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, DefaultIntrospectionClaimsCacheMaxEntries, promMetrics) if opts.ClaimsCache.TTL == 0 { diff --git a/idptoken/provider.go b/idptoken/provider.go index 34d938f..6988769 100644 --- a/idptoken/provider.go +++ b/idptoken/provider.go @@ -171,7 +171,7 @@ func NewMultiSourceProviderWithOpts(sources []Source, opts ProviderOpts) *MultiS minRefreshPeriod: opts.MinRefreshPeriod, logger: idputil.PrepareLogger(opts.Logger), tokenIssuers: make(map[string]*oauth2Issuer), - promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider"), + promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenProvider), customHeaders: opts.CustomHeaders, cache: opts.CustomCacheInstance, httpClient: opts.HTTPClient, diff --git a/idptoken/provider_test.go b/idptoken/provider_test.go index a6165aa..8ec2fc3 100644 --- a/idptoken/provider_test.go +++ b/idptoken/provider_test.go @@ -254,7 +254,7 @@ func TestProviderWithCache(t *testing.T) { metrics.HTTPClientRequestLabelStatusCode: "500", metrics.HTTPClientRequestLabelError: "unexpected_status_code", } - promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider) hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) testutil.AssertSamplesCountInHistogram(t, hist, 1) }) @@ -287,7 +287,7 @@ func TestProviderWithCache(t *testing.T) { metrics.HTTPClientRequestLabelStatusCode: "200", metrics.HTTPClientRequestLabelError: "", } - promMetrics := metrics.GetPrometheusMetrics("", "token_provider") + promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider) hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram) testutil.AssertSamplesCountInHistogram(t, hist, 1) }) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index ee17f8a..8303d5f 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -39,12 +39,31 @@ const ( GRPCClientRequestLabelMethod = "grpc_method" GRPCClientRequestLabelCode = "grpc_code" + + TokenIntrospectionLabelStatus = "status" ) const ( HTTPRequestErrorDo = "do_request_error" HTTPRequestErrorDecodeBody = "decode_body_error" HTTPRequestErrorUnexpectedStatusCode = "unexpected_status_code" + + TokenIntrospectionStatusActive = "active" + TokenIntrospectionStatusNotActive = "not_active" + TokenIntrospectionStatusNotNeeded = "not_needed" + TokenIntrospectionStatusNotIntrospectable = "not_introspectable" + TokenIntrospectionStatusError = "error" +) + +type Source string + +const ( + SourceJWKSClient Source = "jwks_client" + SourceJWTParser Source = "jwt_parser" + SourceGRPCClient Source = "grpc_client" + SourceTokenIntrospector Source = "token_introspector" + SourceTokenProvider Source = "token_provider" + SourceHTTPMiddleware Source = "http_middleware" ) var requestDurationBuckets = []float64{0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} @@ -58,12 +77,13 @@ var ( type PrometheusMetrics struct { HTTPClientRequestDuration *prometheus.HistogramVec GRPCClientRequestDuration *prometheus.HistogramVec + TokenIntrospectionsTotal *prometheus.CounterVec TokenClaimsCache *lrucache.PrometheusMetrics TokenNegativeCache *lrucache.PrometheusMetrics EndpointDiscoveryCache *lrucache.PrometheusMetrics } -func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics { +func GetPrometheusMetrics(instance string, source Source) *PrometheusMetrics { prometheusMetricsOnce.Do(func() { prometheusMetrics = newPrometheusMetrics() prometheusMetrics.MustRegister() @@ -73,7 +93,7 @@ func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics { } return prometheusMetrics.MustCurryWith(map[string]string{ PrometheusLibInstanceLabel: instance, - PrometheusLibSourceLabel: source, + PrometheusLibSourceLabel: string(source), }) } @@ -95,6 +115,7 @@ func newPrometheusMetrics() *PrometheusMetrics { makeLabelNames(HTTPClientRequestLabelMethod, HTTPClientRequestLabelURL, HTTPClientRequestLabelStatusCode, HTTPClientRequestLabelError), ) + grpcClientReqDuration := prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: PrometheusNamespace, @@ -106,6 +127,16 @@ func newPrometheusMetrics() *PrometheusMetrics { makeLabelNames(GRPCClientRequestLabelMethod, GRPCClientRequestLabelCode), ) + tokenIntrospectionsTotal := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: PrometheusNamespace, + Name: "token_introspections_total", + Help: "Total number of tokens' introspections", + ConstLabels: PrometheusLabels(), + }, + makeLabelNames(TokenIntrospectionLabelStatus), + ) + tokenClaimsCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{ Namespace: PrometheusNamespace + "_token_claims", ConstLabels: PrometheusLabels(), @@ -127,6 +158,7 @@ func newPrometheusMetrics() *PrometheusMetrics { return &PrometheusMetrics{ HTTPClientRequestDuration: httpClientReqDuration, GRPCClientRequestDuration: grpcClientReqDuration, + TokenIntrospectionsTotal: tokenIntrospectionsTotal, TokenClaimsCache: tokenClaimsCache, TokenNegativeCache: tokenNegativeCache, EndpointDiscoveryCache: endpointDiscoveryCache, @@ -138,6 +170,7 @@ func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *Prometheus return &PrometheusMetrics{ HTTPClientRequestDuration: pm.HTTPClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec), + TokenIntrospectionsTotal: pm.TokenIntrospectionsTotal.MustCurryWith(labels), TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels), TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels), EndpointDiscoveryCache: pm.EndpointDiscoveryCache.MustCurryWith(labels), @@ -183,3 +216,7 @@ func (pm *PrometheusMetrics) ObserveGRPCClientRequest( GRPCClientRequestLabelCode: code.String(), }).Observe(elapsed.Seconds()) } + +func (pm *PrometheusMetrics) IncTokenIntrospectionsTotal(status string) { + pm.TokenIntrospectionsTotal.With(prometheus.Labels{TokenIntrospectionLabelStatus: status}).Inc() +} diff --git a/jwks/client.go b/jwks/client.go index 28973c3..205a30b 100644 --- a/jwks/client.go +++ b/jwks/client.go @@ -57,7 +57,7 @@ func NewClient() *Client { // NewClientWithOpts returns a new Client with options. func NewClientWithOpts(opts ClientOpts) *Client { - promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client") + promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceJWKSClient) if opts.HTTPClient == nil { opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider) } diff --git a/jwt/caching_parser.go b/jwt/caching_parser.go index b7fa6ae..5d6bcd6 100644 --- a/jwt/caching_parser.go +++ b/jwt/caching_parser.go @@ -48,7 +48,7 @@ func NewCachingParser(keysProvider KeysProvider) (*CachingParser, error) { func NewCachingParserWithOpts( keysProvider KeysProvider, opts CachingParserOpts, ) (*CachingParser, error) { - promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser") + promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, metrics.SourceJWTParser) if opts.CacheMaxEntries == 0 { opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries } diff --git a/middleware.go b/middleware.go index 1b5b69f..8dc2912 100644 --- a/middleware.go +++ b/middleware.go @@ -18,6 +18,7 @@ import ( "github.com/acronis/go-authkit/idptoken" "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/metrics" "github.com/acronis/go-authkit/jwt" ) @@ -70,12 +71,14 @@ type jwtAuthHandler struct { verifyAccess func(r *http.Request, claims jwt.Claims) bool tokenIntrospector TokenIntrospector loggerProvider func(ctx context.Context) log.FieldLogger + promMetrics *metrics.PrometheusMetrics } type jwtAuthMiddlewareOpts struct { - verifyAccess func(r *http.Request, claims jwt.Claims) bool - tokenIntrospector TokenIntrospector - loggerProvider func(ctx context.Context) log.FieldLogger + verifyAccess func(r *http.Request, claims jwt.Claims) bool + tokenIntrospector TokenIntrospector + loggerProvider func(ctx context.Context) log.FieldLogger + prometheusLibInstanceLabel string } // JWTAuthMiddlewareOption is an option for JWTAuthMiddleware. @@ -102,6 +105,13 @@ func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context } } +// WithJWTAuthMiddlewarePrometheusLibInstanceLabel is an option to set a label for Prometheus metrics that are used by JWTAuthMiddleware. +func WithJWTAuthMiddlewarePrometheusLibInstanceLabel(label string) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.prometheusLibInstanceLabel = label + } +} + // JWTAuthMiddleware is a middleware that does authentication // by Access Token from the "Authorization" HTTP header of incoming request. // errorDomain is used for error responses. It is usually the name of the service that uses the middleware, @@ -123,6 +133,7 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM verifyAccess: options.verifyAccess, tokenIntrospector: options.tokenIntrospector, loggerProvider: options.loggerProvider, + promMetrics: metrics.GetPrometheusMetrics(options.prometheusLibInstanceLabel, metrics.SourceHTTPMiddleware), } } } @@ -146,14 +157,17 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token's introspection is not needed") }) + h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded) case errors.Is(err, idptoken.ErrTokenNotIntrospectable): // Token is not introspectable by some reason. // In this case, we will parse it as JWT and use it for authZ. h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is", log.Error(err)) + h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable) default: logger := h.logger(reqCtx) logger.Error("token's introspection failed", log.Error(err)) + h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return @@ -161,6 +175,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } else { if !introspectionResult.IsActive() { h.logger(reqCtx).Warn("token was successfully introspected, but it is not active") + h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) return @@ -169,6 +184,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { logFunc("token was successfully introspected") }) + h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive) } } diff --git a/middleware_test.go b/middleware_test.go index 441d939..66d35f2 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/require" "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/metrics" "github.com/acronis/go-authkit/jwt" ) @@ -122,12 +123,18 @@ func TestJWTAuthMiddleware(t *testing.T) { req.Header.Set(HeaderAuthorization, "Bearer a.b.c") resp := httptest.NewRecorder() + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) + + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 1) }) t.Run("introspection is not needed", func(t *testing.T) { @@ -139,6 +146,9 @@ func TestJWTAuthMiddleware(t *testing.T) { req.Header.Set(HeaderAuthorization, "Bearer a.b.c") resp := httptest.NewRecorder() + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -148,6 +158,9 @@ func TestJWTAuthMiddleware(t *testing.T) { nextIssuer, err := next.jwtClaims.GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) + + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 1) }) t.Run("ok, token is not introspectable", func(t *testing.T) { @@ -159,6 +172,9 @@ func TestJWTAuthMiddleware(t *testing.T) { req.Header.Set(HeaderAuthorization, "Bearer a.b.c") resp := httptest.NewRecorder() + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -169,10 +185,12 @@ func TestJWTAuthMiddleware(t *testing.T) { nextIssuer, err := next.jwtClaims.GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) + + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 1) }) t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) { - const issuer = "my-idp.com" parser := &mockJWTParser{} introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}} next := &mockJWTAuthMiddlewareNextHandler{} @@ -180,12 +198,18 @@ func TestJWTAuthMiddleware(t *testing.T) { req.Header.Set(HeaderAuthorization, "Bearer a.b.c") resp := httptest.NewRecorder() + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed) require.Equal(t, 1, introspector.introspectCalled) require.Equal(t, 0, parser.parseCalled) require.Equal(t, 0, next.called) + + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 1) }) t.Run("ok, token is introspected and active", func(t *testing.T) { @@ -198,6 +222,9 @@ func TestJWTAuthMiddleware(t *testing.T) { req.Header.Set(HeaderAuthorization, "Bearer a.b.c") resp := httptest.NewRecorder() + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0) + JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -208,6 +235,9 @@ func TestJWTAuthMiddleware(t *testing.T) { nextIssuer, err := next.jwtClaims.GetIssuer() require.NoError(t, err) require.Equal(t, issuer, nextIssuer) + + testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware). + TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1) }) }