From d8b08adaaae31f9047305f8fa29bb1c985687cc0 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Fri, 5 Jan 2024 08:07:05 -0600 Subject: [PATCH] Add advanced logging (#11) --- pkg/config/config.go | 20 +++++-- pkg/config/config_test.go | 7 +++ pkg/logger/level.go | 92 +++++++++++++++++++++++++++++ pkg/logger/level_test.go | 85 ++++++++++++++++++++++++++ pkg/logger/logger.go | 40 +++++++++++++ pkg/logger/logger_test.go | 121 ++++++++++++++++++++++++++++++++++++++ pkg/logger/middleware.go | 69 ++++++++++++++++++++++ pkg/server.go | 76 ++++++++++++++++++------ pkg/status.go | 22 ++++--- 9 files changed, 501 insertions(+), 31 deletions(-) create mode 100644 pkg/logger/level.go create mode 100644 pkg/logger/level_test.go create mode 100644 pkg/logger/logger.go create mode 100644 pkg/logger/logger_test.go create mode 100644 pkg/logger/middleware.go diff --git a/pkg/config/config.go b/pkg/config/config.go index cbffd57..d958a19 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -5,15 +5,20 @@ import ( "crypto/x509" "github.com/rotationalio/confire" + "github.com/rs/zerolog" + "github.com/trisacrypto/courier/pkg/logger" "github.com/trisacrypto/trisa/pkg/trust" ) type Config struct { - BindAddr string `split_words:"true" default:":8842"` - Mode string `split_words:"true" default:"release"` - MTLS MTLSConfig `split_words:"true"` - LocalStorage LocalStorageConfig `split_words:"true"` - GCPSecretManager GCPSecretsConfig `split_words:"true"` + Maintenance bool `default:"false"` + BindAddr string `split_words:"true" default:":8842"` + Mode string `split_words:"true" default:"release"` + LogLevel logger.LevelDecoder `split_words:"true" default:"info"` + ConsoleLog bool `split_words:"true" default:"false"` + MTLS MTLSConfig `split_words:"true"` + LocalStorage LocalStorageConfig `split_words:"true"` + GCPSecretManager GCPSecretsConfig `split_words:"true"` processed bool } @@ -95,6 +100,11 @@ func (c Config) Validate() (err error) { return nil } +// Parse and return the zerolog log level for configuring global logging. +func (c Config) GetLogLevel() zerolog.Level { + return zerolog.Level(c.LogLevel) +} + func (c *MTLSConfig) Validate() error { if c.Insecure { return nil diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 7914fbf..27a7f8a 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -4,14 +4,18 @@ import ( "os" "testing" + "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/trisacrypto/courier/pkg/config" ) // Define a test environment for the config tests. var testEnv = map[string]string{ + "COURIER_MAINTENANCE": "true", "COURIER_BIND_ADDR": ":8080", "COURIER_MODE": "debug", + "COURIER_LOG_LEVEL": "warn", + "COURIER_CONSOLE_LOG": "true", "COURIER_MTLS_INSECURE": "false", "COURIER_MTLS_CERT_PATH": "/path/to/cert", "COURIER_MTLS_POOL_PATH": "/path/to/pool", @@ -40,8 +44,11 @@ func TestConfig(t *testing.T) { require.NoError(t, err, "could not create config from test environment") require.False(t, conf.IsZero(), "config should be processed") + require.True(t, conf.Maintenance) require.Equal(t, testEnv["COURIER_BIND_ADDR"], conf.BindAddr) require.Equal(t, testEnv["COURIER_MODE"], conf.Mode) + require.Equal(t, zerolog.WarnLevel, conf.GetLogLevel()) + require.True(t, conf.ConsoleLog) require.False(t, conf.MTLS.Insecure) require.Equal(t, testEnv["COURIER_MTLS_CERT_PATH"], conf.MTLS.CertPath) require.Equal(t, testEnv["COURIER_MTLS_POOL_PATH"], conf.MTLS.PoolPath) diff --git a/pkg/logger/level.go b/pkg/logger/level.go new file mode 100644 index 0000000..10a8356 --- /dev/null +++ b/pkg/logger/level.go @@ -0,0 +1,92 @@ +package logger + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/rs/zerolog" +) + +// LogLevelDecoder deserializes the log level from a config string. +type LevelDecoder zerolog.Level + +// Names of log levels for use in encoding/decoding from strings. +const ( + llPanic = "panic" + llFatal = "fatal" + llError = "error" + llWarn = "warn" + llInfo = "info" + llDebug = "debug" + llTrace = "trace" +) + +// Decode implements confire Decoder interface. +func (ll *LevelDecoder) Decode(value string) error { + value = strings.TrimSpace(strings.ToLower(value)) + switch value { + case llPanic: + *ll = LevelDecoder(zerolog.PanicLevel) + case llFatal: + *ll = LevelDecoder(zerolog.FatalLevel) + case llError: + *ll = LevelDecoder(zerolog.ErrorLevel) + case llWarn: + *ll = LevelDecoder(zerolog.WarnLevel) + case llInfo: + *ll = LevelDecoder(zerolog.InfoLevel) + case llDebug: + *ll = LevelDecoder(zerolog.DebugLevel) + case llTrace: + *ll = LevelDecoder(zerolog.TraceLevel) + default: + return fmt.Errorf("unknown log level %q", value) + } + return nil +} + +// Encode converts the loglevel into a string for use in YAML and JSON +func (ll *LevelDecoder) Encode() (string, error) { + switch zerolog.Level(*ll) { + case zerolog.PanicLevel: + return llPanic, nil + case zerolog.FatalLevel: + return llFatal, nil + case zerolog.ErrorLevel: + return llError, nil + case zerolog.WarnLevel: + return llWarn, nil + case zerolog.InfoLevel: + return llInfo, nil + case zerolog.DebugLevel: + return llDebug, nil + case zerolog.TraceLevel: + return llTrace, nil + default: + return "", fmt.Errorf("unknown log level %d", ll) + } +} + +func (ll LevelDecoder) String() string { + ls, _ := ll.Encode() + return ls +} + +// UnmarshalJSON implements json.Unmarshaler +func (ll *LevelDecoder) UnmarshalJSON(data []byte) error { + var ls string + if err := json.Unmarshal(data, &ls); err != nil { + return err + } + return ll.Decode(ls) +} + +// MarshalJSON implements json.Marshaler +func (ll LevelDecoder) MarshalJSON() ([]byte, error) { + ls, err := ll.Encode() + if err != nil { + return nil, err + } + return json.Marshal(ls) +} diff --git a/pkg/logger/level_test.go b/pkg/logger/level_test.go new file mode 100644 index 0000000..4fc5375 --- /dev/null +++ b/pkg/logger/level_test.go @@ -0,0 +1,85 @@ +package logger_test + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "github.com/trisacrypto/courier/pkg/logger" +) + +func TestLevelDecoder(t *testing.T) { + testTable := []struct { + value string + expected zerolog.Level + }{ + { + "panic", zerolog.PanicLevel, + }, + { + "FATAL", zerolog.FatalLevel, + }, + { + "Error", zerolog.ErrorLevel, + }, + { + " warn ", zerolog.WarnLevel, + }, + { + "iNFo", zerolog.InfoLevel, + }, + { + "debug", zerolog.DebugLevel, + }, + { + "trace", zerolog.TraceLevel, + }, + } + + // Test valid cases + for _, testCase := range testTable { + var level logger.LevelDecoder + err := level.Decode(testCase.value) + require.NoError(t, err) + require.Equal(t, testCase.expected, zerolog.Level(level)) + } + + // Test error case + var level logger.LevelDecoder + err := level.Decode("notalevel") + require.EqualError(t, err, `unknown log level "notalevel"`) +} + +func TestUnmarshaler(t *testing.T) { + type Config struct { + Level logger.LevelDecoder + } + + var jsonConf Config + err := json.Unmarshal([]byte(`{"level": "panic"}`), &jsonConf) + require.NoError(t, err, "could not unmarshal level decoder in json file") + require.Equal(t, zerolog.PanicLevel, zerolog.Level(jsonConf.Level)) +} + +func TestMarshaler(t *testing.T) { + confs := []struct { + Level logger.LevelDecoder `yaml:"level" json:"level"` + }{ + {logger.LevelDecoder(zerolog.PanicLevel)}, + {logger.LevelDecoder(zerolog.FatalLevel)}, + {logger.LevelDecoder(zerolog.ErrorLevel)}, + {logger.LevelDecoder(zerolog.WarnLevel)}, + {logger.LevelDecoder(zerolog.InfoLevel)}, + {logger.LevelDecoder(zerolog.DebugLevel)}, + {logger.LevelDecoder(zerolog.TraceLevel)}, + } + + for _, conf := range confs { + data, err := json.Marshal(conf) + require.NoError(t, err, "could not marshal data into json") + require.Equal(t, []byte(fmt.Sprintf(`{"level":%q}`, &conf.Level)), data) + } + +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..78de414 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,40 @@ +package logger + +import "github.com/rs/zerolog" + +type severityGCP string + +const ( + GCPAlertLevel severityGCP = "ALERT" + GCPCriticalLevel severityGCP = "CRITICAL" + GCPErrorLevel severityGCP = "ERROR" + GCPWarningLevel severityGCP = "WARNING" + GCPInfoLevel severityGCP = "INFO" + GCPDebugLevel severityGCP = "DEBUG" + + GCPFieldKeySeverity = "severity" + GCPFieldKeyMsg = "message" + GCPFieldKeyTime = "time" +) + +var ( + zerologToGCPLevel = map[zerolog.Level]severityGCP{ + zerolog.PanicLevel: GCPAlertLevel, + zerolog.FatalLevel: GCPCriticalLevel, + zerolog.ErrorLevel: GCPErrorLevel, + zerolog.WarnLevel: GCPWarningLevel, + zerolog.InfoLevel: GCPInfoLevel, + zerolog.DebugLevel: GCPDebugLevel, + zerolog.TraceLevel: GCPDebugLevel, + } +) + +// SeverityHook adds GCP severity levels to zerolog output log messages. +type SeverityHook struct{} + +// Run implements the zerolog.Hook interface. +func (h SeverityHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { + if level != zerolog.NoLevel { + e.Str(GCPFieldKeySeverity, string(zerologToGCPLevel[level])) + } +} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 0000000..2209e3a --- /dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,121 @@ +package logger_test + +import ( + "encoding/json" + "errors" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "github.com/trisacrypto/courier/pkg/logger" +) + +type testWriter struct { + lastLog map[string]interface{} + levels map[zerolog.Level]uint16 +} + +func (w *testWriter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) { + if w.levels == nil { + w.levels = make(map[zerolog.Level]uint16) + } + w.levels[level]++ + return w.Write(p) +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + if err = json.Unmarshal(p, &w.lastLog); err != nil { + return 0, err + } + return len(p), nil +} + +func TestSeverityHook(t *testing.T) { + // Initialize zerolog with GCP logging requirements + zerolog.TimeFieldFormat = time.RFC3339 + zerolog.TimestampFieldName = logger.GCPFieldKeyTime + zerolog.MessageFieldName = logger.GCPFieldKeyMsg + + // Test writer + tw := &testWriter{} + + // Add the severity hook for GCP logging + var gcpHook logger.SeverityHook + log.Logger = zerolog.New(tw).Hook(gcpHook).With().Timestamp().Logger() + + log.Trace().Msg("just a trace") + require.Equal(t, uint16(1), tw.levels[zerolog.TraceLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "DEBUG", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "just a trace", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + + log.Debug().Msg("is it on?") + require.Equal(t, uint16(1), tw.levels[zerolog.DebugLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "DEBUG", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "is it on?", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + + log.Info().Str("extra", "foo").Msg("my name is bob") + require.Equal(t, uint16(1), tw.levels[zerolog.InfoLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "INFO", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "my name is bob", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, "extra") + require.Equal(t, "foo", tw.lastLog["extra"]) + + log.Warn().Msg("don't run with scissors") + require.Equal(t, uint16(1), tw.levels[zerolog.WarnLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "WARNING", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "don't run with scissors", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + + log.Error().Err(errors.New("bad things")).Msg("oops") + require.Equal(t, uint16(1), tw.levels[zerolog.ErrorLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "ERROR", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "oops", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, "error") + require.Equal(t, "bad things", tw.lastLog["error"]) + + // Must use WithLevel or the program will exit and the test will fail. + log.WithLevel(zerolog.FatalLevel).Err(errors.New("murder")).Msg("dying") + require.Equal(t, uint16(1), tw.levels[zerolog.FatalLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "CRITICAL", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "dying", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, "error") + require.Equal(t, "murder", tw.lastLog["error"]) + + require.Panics(t, func() { + log.Panic().Err(errors.New("run away!")).Msg("squeeeee!!!") + }) + require.Equal(t, uint16(1), tw.levels[zerolog.PanicLevel]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeySeverity) + require.Equal(t, "ALERT", tw.lastLog[logger.GCPFieldKeySeverity]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyMsg) + require.Equal(t, "squeeeee!!!", tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, logger.GCPFieldKeyTime) + require.NotEmpty(t, tw.lastLog[logger.GCPFieldKeyMsg]) + require.Contains(t, tw.lastLog, "error") + require.Equal(t, "run away!", tw.lastLog["error"]) +} diff --git a/pkg/logger/middleware.go b/pkg/logger/middleware.go new file mode 100644 index 0000000..95e5a8e --- /dev/null +++ b/pkg/logger/middleware.go @@ -0,0 +1,69 @@ +package logger + +import ( + "fmt" + "time" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +// GinLogger returns a new Gin middleware that performs logging for our JSON APIs using +// zerolog rather than the default Gin logger which is a standard HTTP logger. +// NOTE: we previously used github.com/dn365/gin-zerolog but wanted more customization. +func GinLogger(server, version string) gin.HandlerFunc { + return func(c *gin.Context) { + // Before request + started := time.Now() + + path := c.Request.URL.Path + if c.Request.URL.RawQuery != "" { + path = path + "?" + c.Request.URL.RawQuery + } + + // Handle the request + c.Next() + + // After request + status := c.Writer.Status() + logctx := log.With(). + Str("path", path). + Str("ser_name", server). + Str("version", version). + Str("method", c.Request.Method). + Dur("resp_time", time.Since(started)). + Int("resp_bytes", c.Writer.Size()). + Int("status", status). + Str("client_ip", c.ClientIP()). + Logger() + + // Log any errors that were added to the context + if len(c.Errors) > 0 { + errs := make([]error, 0, len(c.Errors)) + for _, err := range c.Errors { + errs = append(errs, err) + } + logctx = logctx.With().Errs("errors", errs).Logger() + } + + // Create the message to send to the logger. + var msg string + switch len(c.Errors) { + case 0: + msg = fmt.Sprintf("%s %s %s %d", server, c.Request.Method, c.Request.URL.Path, status) + case 1: + msg = c.Errors.String() + default: + msg = fmt.Sprintf("%s %s %s [%d] %d errors occurred", server, c.Request.Method, c.Request.URL.Path, status, len(c.Errors)) + } + + switch { + case status >= 400 && status < 500: + logctx.Warn().Msg(msg) + case status >= 500: + logctx.Error().Msg(msg) + default: + logctx.Info().Msg(msg) + } + } +} diff --git a/pkg/server.go b/pkg/server.go index a65ef3f..394d32a 100644 --- a/pkg/server.go +++ b/pkg/server.go @@ -11,14 +11,29 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/trisacrypto/courier/pkg/api/v1" "github.com/trisacrypto/courier/pkg/config" + "github.com/trisacrypto/courier/pkg/logger" "github.com/trisacrypto/courier/pkg/store" "github.com/trisacrypto/courier/pkg/store/gcloud" "github.com/trisacrypto/courier/pkg/store/local" ) +func init() { + // Initializes zerolog with our default logging requirements + zerolog.TimeFieldFormat = time.RFC3339 + zerolog.TimestampFieldName = logger.GCPFieldKeyTime + zerolog.MessageFieldName = logger.GCPFieldKeyMsg + zerolog.DurationFieldInteger = false + zerolog.DurationFieldUnit = time.Millisecond + + // Add the severity hook for GCP logging + var gcpHook logger.SeverityHook + log.Logger = zerolog.New(os.Stdout).Hook(gcpHook).With().Timestamp().Logger() +} + // New creates a new server object from configuration but does not serve it yet. func New(conf config.Config) (s *Server, err error) { // Load config from environment if it's empty @@ -28,6 +43,12 @@ func New(conf config.Config) (s *Server, err error) { } } + // Setup our logging config first thing + zerolog.SetGlobalLevel(conf.GetLogLevel()) + if conf.ConsoleLog { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + } + // Create the server object s = &Server{ conf: conf, @@ -35,30 +56,43 @@ func New(conf config.Config) (s *Server, err error) { } // Open the store - switch { - case conf.LocalStorage.Enabled: - if s.store, err = local.Open(conf.LocalStorage); err != nil { - return nil, err - } - case conf.GCPSecretManager.Enabled: - if s.store, err = gcloud.Open(conf.GCPSecretManager); err != nil { - return nil, err + if !s.conf.Maintenance { + switch { + case s.conf.LocalStorage.Enabled: + if s.store, err = local.Open(s.conf.LocalStorage); err != nil { + return nil, err + } + case s.conf.GCPSecretManager.Enabled: + if s.store, err = gcloud.Open(s.conf.GCPSecretManager); err != nil { + return nil, err + } + default: + return nil, errors.New("no storage backend configured") } - default: - return nil, errors.New("no storage backend configured") } // Create the router gin.SetMode(conf.Mode) s.router = gin.New() + s.router.RedirectTrailingSlash = true + s.router.RedirectFixedPath = false + s.router.HandleMethodNotAllowed = true + s.router.ForwardedByClientIP = true + s.router.UseRawPath = false + s.router.UnescapePathValues = true + if err = s.setupRoutes(); err != nil { return nil, err } // Create the http server s.srv = &http.Server{ - Addr: conf.BindAddr, - Handler: s.router, + Addr: conf.BindAddr, + Handler: s.router, + ErrorLog: nil, + ReadHeaderTimeout: 20 * time.Second, + WriteTimeout: 20 * time.Second, + IdleTimeout: 90 * time.Second, } // Use TLS if configured @@ -115,6 +149,7 @@ func (s *Server) Serve() (err error) { } }() + s.SetReady(true) log.Info().Str("listen", s.url).Str("version", Version()).Msg("courier server started") // Wait for shutdown or an error @@ -135,14 +170,18 @@ func (s *Server) Shutdown() (err error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err = s.srv.Shutdown(ctx); err != nil { - return err + if serr := s.srv.Shutdown(ctx); serr != nil { + err = errors.Join(err, serr) } - // TODO: Close the stores + if !s.conf.Maintenance { + if serr := s.store.Close(); serr != nil { + err = errors.Join(err, serr) + } + } - log.Debug().Msg("successfully shut down courier server") - return nil + log.Debug().Err(err).Msg("shut down courier server") + return err } // Setup the routes for the courier service. @@ -154,7 +193,7 @@ func (s *Server) setupRoutes() (err error) { s.router.GET("/readyz", s.Readyz) middlewares := []gin.HandlerFunc{ - gin.Logger(), + logger.GinLogger("courier", Version()), gin.Recovery(), s.Available(), } @@ -179,7 +218,6 @@ func (s *Server) setupRoutes() (err error) { // Not found and method not allowed routes s.router.NoRoute(api.NotFound) s.router.NoMethod(api.MethodNotAllowed) - return nil } diff --git a/pkg/status.go b/pkg/status.go index b131e4d..1b121f4 100644 --- a/pkg/status.go +++ b/pkg/status.go @@ -9,8 +9,9 @@ import ( ) const ( - serverStatusOK = "ok" - serverStatusStopping = "stopping" + serverStatusOK = "ok" + serverStatusStopping = "stopping" + serverStatusMaintenance = "maintenance" ) // Status returns the status of the server. @@ -30,21 +31,28 @@ func (s *Server) Status(c *gin.Context) { // http status code if the server is shutting down. This middleware must be first in the // chain to ensure that complex handling to slow the shutdown of the server. func (s *Server) Available() gin.HandlerFunc { + // The server starts in maintenance mode and doesn't change during runtime, so + // determine what the unhealthy status string is going to be prior to the closure. + status := serverStatusStopping + if s.conf.Maintenance { + status = serverStatusMaintenance + } + return func(c *gin.Context) { // Check health status - s.RLock() - if !s.healthy { + if s.conf.Maintenance || !s.IsReady() { c.JSON(http.StatusServiceUnavailable, api.StatusReply{ - Status: serverStatusStopping, + Status: status, Uptime: time.Since(s.started).String(), Version: Version(), }) + // Stop processing the request if the server is not ready c.Abort() - s.RUnlock() return } - s.RUnlock() + + // Continue processing the request c.Next() } }