From 17125e37007dbc35f6bda3ddaa9b9b0ef42f2129 Mon Sep 17 00:00:00 2001 From: giotto Date: Thu, 13 Jun 2024 22:03:51 +0200 Subject: [PATCH] Feat/unittests (#60) - add unit tests - small code improvements --- go.mod | 4 +- go.sum | 4 + main.go | 2 +- pkg/config/config_test.go | 118 +++++++++++ pkg/flags/flags_test.go | 86 ++++++++ pkg/handlers/healthz_test.go | 23 +++ pkg/handlers/heartbeats.go | 6 +- pkg/handlers/heartbeats_test.go | 106 ++++++++++ pkg/handlers/history.go | 4 +- pkg/handlers/history_test.go | 126 ++++++++++++ pkg/handlers/metrics_test.go | 41 ++++ pkg/handlers/ping_test.go | 79 ++++++++ pkg/handlers/utils_test.go | 59 ++++++ pkg/heartbeat/heartbeat_test.go | 189 ++++++++++++++++++ pkg/history/enums.go | 1 - pkg/history/history_test.go | 169 ++++++++++++++++ pkg/metrics/metrics_test.go | 69 +++++++ pkg/notify/notifier/email.go | 20 +- pkg/notify/notifier/msteams.go | 8 +- pkg/notify/notifier/slack.go | 12 +- .../resolve.go => resolver/resolver.go} | 2 +- pkg/notify/resolver/resolver_test.go | 78 ++++++++ pkg/notify/services/slack/slack.go | 6 +- pkg/notify/utils/request_test.go | 74 +++++++ pkg/notify/utils/template.go | 2 +- pkg/notify/utils/template_test.go | 52 +++++ pkg/server/routes.go | 6 +- pkg/server/routes_test.go | 178 +++++++++++++++++ pkg/server/server_test.go | 26 +++ pkg/timer/timer_test.go | 76 +++++++ .../{heartbeat.html => heartbeats.html} | 0 31 files changed, 1589 insertions(+), 37 deletions(-) create mode 100644 pkg/config/config_test.go create mode 100644 pkg/flags/flags_test.go create mode 100644 pkg/handlers/healthz_test.go create mode 100644 pkg/handlers/heartbeats_test.go create mode 100644 pkg/handlers/history_test.go create mode 100644 pkg/handlers/metrics_test.go create mode 100644 pkg/handlers/ping_test.go create mode 100644 pkg/handlers/utils_test.go create mode 100644 pkg/heartbeat/heartbeat_test.go create mode 100644 pkg/history/history_test.go create mode 100644 pkg/metrics/metrics_test.go rename pkg/notify/{resolve/resolve.go => resolver/resolver.go} (99%) create mode 100644 pkg/notify/resolver/resolver_test.go create mode 100644 pkg/notify/utils/request_test.go create mode 100644 pkg/notify/utils/template_test.go create mode 100644 pkg/server/routes_test.go create mode 100644 pkg/server/server_test.go create mode 100644 pkg/timer/timer_test.go rename web/templates/{heartbeat.html => heartbeats.html} (100%) diff --git a/go.mod b/go.mod index d9016c7..022ed08 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,9 @@ require ( github.com/Masterminds/sprig v2.22.0+incompatible github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -26,8 +28,8 @@ require ( github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.33.0 // indirect ) diff --git a/go.sum b/go.sum index 6cf90c4..77c2d40 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -55,6 +57,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index a2777be..0462028 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,7 @@ import ( "os" ) -const version = "0.6.8" +const version = "0.6.7" //go:embed web var templates embed.FS diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..2467cd9 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,118 @@ +package config + +import ( + "heartbeats/pkg/heartbeat" + "heartbeats/pkg/history" + "heartbeats/pkg/notify" + "heartbeats/pkg/notify/notifier" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +// Sample configuration YAML for testing. +const sampleConfig = ` +version: "1.0.0" +verbose: true +path: "./config.yaml" +server: + siteRoot: "http://localhost:8080" + listenAddress: "localhost:8080" +cache: + maxSize: 100 + reduce: 10 +notifications: + slack: + type: "slack" + slack_config: + channel: "general" +heartbeats: + heartbeat1: + name: "heartbeat1" + interval: "1m" + grace: "1m" + notifications: + - slack +` + +func writeSampleConfig(t *testing.T, content string) string { + file, err := os.CreateTemp("", "config*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer file.Close() + + if _, err := file.WriteString(content); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + + return file.Name() +} + +func TestConfig_Read(t *testing.T) { + App.NotificationStore = notify.NewStore() + HistoryStore = history.NewStore() + + tempFile := writeSampleConfig(t, sampleConfig) + defer os.Remove(tempFile) + + App.Path = tempFile + + err := App.Read() + assert.NoError(t, err, "Expected no error when reading the config file") + + notification := App.NotificationStore.Get("slack") + assert.NotNil(t, notification, "Expected slack notification to be present") + assert.Equal(t, `{{ if eq .Status "ok" }}good{{ else }}danger{{ end }}`, notification.SlackConfig.ColorTemplate) + + heartbeat := App.HeartbeatStore.Get("heartbeat1") + assert.NotNil(t, heartbeat, "Expected heartbeat1 to be present") + assert.Equal(t, "heartbeat1", heartbeat.Name) +} + +func TestProcessNotifications(t *testing.T) { + App.NotificationStore = notify.NewStore() + HistoryStore = history.NewStore() + + var rawConfig map[string]interface{} + err := yaml.Unmarshal([]byte(sampleConfig), &rawConfig) + assert.NoError(t, err) + err = App.processNotifications(rawConfig["notifications"]) + assert.NoError(t, err, "Expected no error when processing notifications") + + notification := App.NotificationStore.Get("slack") + assert.NotNil(t, notification, "Expected slack notification to be present") + assert.Equal(t, "slack", notification.Type) + assert.Equal(t, `{{ if eq .Status "ok" }}good{{ else }}danger{{ end }}`, notification.SlackConfig.ColorTemplate) +} + +func TestProcessHeartbeats(t *testing.T) { + App.HeartbeatStore = heartbeat.NewStore() + HistoryStore = history.NewStore() + + var rawConfig map[string]interface{} + err := yaml.Unmarshal([]byte(sampleConfig), &rawConfig) + assert.NoError(t, err) + + err = App.processHeartbeats(rawConfig["heartbeats"]) + assert.NoError(t, err, "Expected no error when processing heartbeats") + + heartbeat := App.HeartbeatStore.Get("heartbeat1") + assert.NotNil(t, heartbeat, "Expected heartbeat1 to be present") + assert.Equal(t, "heartbeat1", heartbeat.Name) +} + +func TestUpdateSlackNotification(t *testing.T) { + notification := ¬ify.Notification{ + Type: "slack", + SlackConfig: ¬ifier.SlackConfig{ + Channel: "general", + }, + } + + err := App.updateSlackNotification("slack", notification) + assert.NoError(t, err, "Expected no error when updating slack notification") + assert.Equal(t, `{{ if eq .Status "ok" }}good{{ else }}danger{{ end }}`, notification.SlackConfig.ColorTemplate) +} diff --git a/pkg/flags/flags_test.go b/pkg/flags/flags_test.go new file mode 100644 index 0000000..f799a4a --- /dev/null +++ b/pkg/flags/flags_test.go @@ -0,0 +1,86 @@ +package flags + +import ( + "heartbeats/pkg/config" + "os" + "strings" + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +func resetFlags() { + pflag.CommandLine = pflag.NewFlagSet("heartbeats", pflag.ExitOnError) +} + +func TestParseFlags(t *testing.T) { + resetFlags() + + output := &strings.Builder{} + args := []string{"cmd", "-c", "config.yaml", "-l", "127.0.0.1:9090", "-s", "http://example.com", "-m", "200", "-r", "20", "-v"} + + result := ParseFlags(args, output) + assert.NoError(t, result.Err) + assert.Equal(t, result.ShowVersion, false) + assert.Equal(t, result.ShowHelp, false) + + assert.Equal(t, "config.yaml", config.App.Path) + assert.Equal(t, "127.0.0.1:9090", config.App.Server.ListenAddress) + assert.Equal(t, "http://example.com", config.App.Server.SiteRoot) + assert.Equal(t, 200, config.App.Cache.MaxSize) + assert.Equal(t, 20, config.App.Cache.Reduce) + assert.True(t, config.App.Verbose) +} + +func TestShowVersionFlag(t *testing.T) { + resetFlags() + + output := &strings.Builder{} + args := []string{"cmd", "--version"} + result := ParseFlags(args, output) + assert.NoError(t, result.Err) + assert.Equal(t, result.ShowVersion, true) + assert.Equal(t, result.ShowHelp, false) +} + +func TestShowHelpFlag(t *testing.T) { + resetFlags() + + output := &strings.Builder{} + args := []string{"cmd", "--help"} + + result := ParseFlags(args, output) + assert.NoError(t, result.Err) + assert.Equal(t, result.ShowVersion, false) + assert.Equal(t, result.ShowHelp, true) +} + +func TestProcessEnvVariables(t *testing.T) { + resetFlags() + + os.Setenv("HEARTBEATS_CONFIG", "env_config.yaml") + os.Setenv("HEARTBEATS_LISTEN_ADDRESS", "0.0.0.0:8080") + os.Setenv("HEARTBEATS_SITE_ROOT", "http://env.com") + os.Setenv("HEARTBEATS_MAX_SIZE", "300") + os.Setenv("HEARTBEATS_REDUCE", "30") + os.Setenv("HEARTBEATS_VERBOSE", "true") + + pflag.StringVarP(&config.App.Path, "config", "c", "./deploy/config.yaml", "Path to the configuration file") + pflag.StringVarP(&config.App.Server.ListenAddress, "listen-address", "l", "localhost:8080", "Address to listen on") + pflag.StringVarP(&config.App.Server.SiteRoot, "site-root", "s", "", "Site root for the heartbeat service (default \"http://\")") + pflag.IntVarP(&config.App.Cache.MaxSize, "max-size", "m", 100, "Maximum size of the cache") + pflag.IntVarP(&config.App.Cache.Reduce, "reduce", "r", 10, "Amount to reduce when max size is exceeded") + pflag.BoolVarP(&config.App.Verbose, "verbose", "v", false, "Enable verbose logging") + + pflag.Parse() + + processEnvVariables() + + assert.Equal(t, "env_config.yaml", config.App.Path) + assert.Equal(t, "0.0.0.0:8080", config.App.Server.ListenAddress) + assert.Equal(t, "http://env.com", config.App.Server.SiteRoot) + assert.Equal(t, 300, config.App.Cache.MaxSize) + assert.Equal(t, 30, config.App.Cache.Reduce) + assert.True(t, config.App.Verbose) +} diff --git a/pkg/handlers/healthz_test.go b/pkg/handlers/healthz_test.go new file mode 100644 index 0000000..1d4aad7 --- /dev/null +++ b/pkg/handlers/healthz_test.go @@ -0,0 +1,23 @@ +package handlers + +import ( + "heartbeats/pkg/logger" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHealthz(t *testing.T) { + log := logger.NewLogger(true) + handler := Healthz(log) + + req := httptest.NewRequest("GET", "/healthz", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "Expected status code 200") + assert.Equal(t, "ok", rec.Body.String(), "Expected response body 'ok'") +} diff --git a/pkg/handlers/heartbeats.go b/pkg/handlers/heartbeats.go index 359d352..bb27baf 100644 --- a/pkg/handlers/heartbeats.go +++ b/pkg/handlers/heartbeats.go @@ -1,11 +1,11 @@ package handlers import ( - "embed" "heartbeats/pkg/config" "heartbeats/pkg/logger" "heartbeats/pkg/timer" "html/template" + "io/fs" "net/http" "time" @@ -37,7 +37,7 @@ type NotificationState struct { } // Heartbeats handles the / endpoint -func Heartbeats(logger logger.Logger, staticFS embed.FS) http.HandlerFunc { +func Heartbeats(logger logger.Logger, staticFS fs.FS) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { fmap := sprig.TxtFuncMap() fmap["isTrue"] = isTrue @@ -48,7 +48,7 @@ func Heartbeats(logger logger.Logger, staticFS embed.FS) http.HandlerFunc { Funcs(fmap). ParseFS( staticFS, - "web/templates/heartbeat.html", + "web/templates/heartbeats.html", "web/templates/footer.html", ) if err != nil { diff --git a/pkg/handlers/heartbeats_test.go b/pkg/handlers/heartbeats_test.go new file mode 100644 index 0000000..5ee215d --- /dev/null +++ b/pkg/handlers/heartbeats_test.go @@ -0,0 +1,106 @@ +package handlers + +import ( + "heartbeats/pkg/config" + "heartbeats/pkg/heartbeat" + "heartbeats/pkg/history" + "heartbeats/pkg/logger" + "heartbeats/pkg/notify" + "heartbeats/pkg/notify/notifier" + "heartbeats/pkg/timer" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" +) + +func setupAferoFSForHeartbeats() afero.Fs { + aferoFS := afero.NewMemMapFs() + + templateFiles := []string{ + "heartbeats.html", + "footer.html", + } + + for _, file := range templateFiles { + content, err := os.ReadFile(filepath.Join("../../web/templates", file)) + if err != nil { + panic(err) + } + + err = afero.WriteFile(aferoFS, "web/templates/"+file, content, 0644) + if err != nil { + panic(err) + } + } + + return aferoFS +} + +func TestHeartbeatsHandler(t *testing.T) { + log := logger.NewLogger(true) + config.App.HeartbeatStore = heartbeat.NewStore() + config.App.NotificationStore = notify.NewStore() + config.HistoryStore = history.NewStore() + + h := &heartbeat.Heartbeat{ + Name: "test", + Enabled: new(bool), + Interval: &timer.Timer{Interval: new(time.Duration)}, + Grace: &timer.Timer{Interval: new(time.Duration)}, + Notifications: []string{"test"}, + } + *h.Enabled = true + *h.Interval.Interval = time.Minute + *h.Grace.Interval = time.Minute + + err := config.App.HeartbeatStore.Add("test", h) + assert.NoError(t, err) + + ns := ¬ify.Notification{ + Name: "test", + Type: "email", + Enabled: new(bool), + MailConfig: ¬ifier.MailConfig{}, + } + *ns.Enabled = false + + err = config.App.NotificationStore.Add("test", ns) + assert.NoError(t, err) + + aferoFS := setupAferoFSForHeartbeats() + + mux := http.NewServeMux() + mux.Handle("/", Heartbeats(log, aferoToCustomAferoFS(aferoFS))) + + t.Run("Heartbeat page renders correctly", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "Expected status code 200") + assert.Contains(t, rec.Body.String(), "Heartbeat", "Expected 'Heartbeat' in response body") + assert.Contains(t, rec.Body.String(), "test", "Expected 'test' in response body") + }) + + // Simulate a template parsing error by using an invalid template path + t.Run("Template parsing error", func(t *testing.T) { + invalidFS := afero.NewMemMapFs() // Empty FS to simulate missing templates + mux := http.NewServeMux() + mux.Handle("/", Heartbeats(log, aferoToCustomAferoFS(invalidFS))) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, "Expected status code 500") + assert.Contains(t, rec.Body.String(), "Internal Server Error", "Expected internal server error message") + }) +} diff --git a/pkg/handlers/history.go b/pkg/handlers/history.go index bc93156..256e456 100644 --- a/pkg/handlers/history.go +++ b/pkg/handlers/history.go @@ -1,17 +1,17 @@ package handlers import ( - "embed" "fmt" "heartbeats/pkg/config" "heartbeats/pkg/history" "heartbeats/pkg/logger" "html/template" + "io/fs" "net/http" ) // History handles the /history/{id} endpoint -func History(logger logger.Logger, staticFS embed.FS) http.Handler { +func History(logger logger.Logger, staticFS fs.FS) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { heartbeatName := r.PathValue("id") logger.Debugf("%s /history/%s %s %s", r.Method, heartbeatName, r.RemoteAddr, r.UserAgent()) diff --git a/pkg/handlers/history_test.go b/pkg/handlers/history_test.go new file mode 100644 index 0000000..cfc5be9 --- /dev/null +++ b/pkg/handlers/history_test.go @@ -0,0 +1,126 @@ +package handlers + +import ( + "heartbeats/pkg/config" + "heartbeats/pkg/heartbeat" + "heartbeats/pkg/history" + "heartbeats/pkg/logger" + "heartbeats/pkg/notify" + "heartbeats/pkg/timer" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" +) + +func setupAferoFSForHistory() afero.Fs { + aferoFS := afero.NewMemMapFs() + + templateFiles := []string{ + "history.html", + "footer.html", + } + + for _, file := range templateFiles { + content, err := os.ReadFile(filepath.Join("../../web/templates", file)) + if err != nil { + panic(err) + } + + err = afero.WriteFile(aferoFS, "web/templates/"+file, content, 0644) + if err != nil { + panic(err) + } + } + + return aferoFS +} + +// customAferoFS implements the fs.FS interface for afero.Fs +type customAferoFS struct { + fs afero.Fs +} + +// Open implements the fs.FS interface +func (a *customAferoFS) Open(name string) (fs.File, error) { + return a.fs.Open(name) +} + +// Convert the afero.Fs to customAferoFS +func aferoToCustomAferoFS(afs afero.Fs) fs.FS { + return &customAferoFS{fs: afs} +} + +func TestHistoryHandler(t *testing.T) { + log := logger.NewLogger(true) + config.App.HeartbeatStore = heartbeat.NewStore() + config.App.NotificationStore = notify.NewStore() + config.HistoryStore = history.NewStore() + + h := &heartbeat.Heartbeat{ + Name: "test", + Enabled: new(bool), + Interval: &timer.Timer{Interval: new(time.Duration)}, + Grace: &timer.Timer{Interval: new(time.Duration)}, + } + *h.Enabled = true + *h.Interval.Interval = time.Minute + *h.Grace.Interval = time.Minute + + err := config.App.HeartbeatStore.Add("test", h) + assert.NoError(t, err) + + hist, err := history.NewHistory(10, 2) + assert.NoError(t, err) + + err = config.HistoryStore.Add("test", hist) + assert.NoError(t, err) + + ns := notify.NewStore() + config.App.NotificationStore = ns + + mux := http.NewServeMux() + aferoFS := setupAferoFSForHistory() + mux.Handle("GET /history/{id}", History(log, aferoToCustomAferoFS(aferoFS))) + + t.Run("Heartbeat not found", func(t *testing.T) { + req := httptest.NewRequest("GET", "/history/nonexistent", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code, "Expected status code 404") + assert.Contains(t, rec.Body.String(), "Heartbeat 'nonexistent' not found", "Expected heartbeat not found message") + }) + + t.Run("Heartbeat found and history retrieved", func(t *testing.T) { + req := httptest.NewRequest("GET", "/history/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "Expected status code 200") + assert.Contains(t, rec.Body.String(), "History for test", "Expected history content") + }) + + // Simulate a template parsing error by using an invalid template path + t.Run("Template parsing error", func(t *testing.T) { + invalidFS := afero.NewMemMapFs() // Empty FS to simulate missing templates + mux := http.NewServeMux() + mux.Handle("GET /history/{id}", History(log, aferoToCustomAferoFS(invalidFS))) + + req := httptest.NewRequest("GET", "/history/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, "Expected status code 500") + assert.Contains(t, rec.Body.String(), "Internal Server Error", "Expected internal server error message") + }) +} diff --git a/pkg/handlers/metrics_test.go b/pkg/handlers/metrics_test.go new file mode 100644 index 0000000..19ad6c7 --- /dev/null +++ b/pkg/handlers/metrics_test.go @@ -0,0 +1,41 @@ +package handlers + +import ( + "heartbeats/pkg/logger" + "heartbeats/pkg/metrics" + "net/http" + "net/http/httptest" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +// TestMetricsHandler tests the Metrics handler. +func TestMetricsHandler(t *testing.T) { + log := logger.NewLogger(true) + + // Register a test metric + testMetric := prometheus.NewCounter(prometheus.CounterOpts{ + Name: "test_metric", + Help: "This is a test metric", + }) + metrics.PromMetrics.Registry.MustRegister(testMetric) + + // Increment the test metric + testMetric.Inc() + + // Create the Metrics handler + handler := Metrics(log) + + // Create a new HTTP request + req := httptest.NewRequest("GET", "/metrics", nil) + rec := httptest.NewRecorder() + + // Serve the HTTP request + handler.ServeHTTP(rec, req) + + // Check the status code and response body + assert.Equal(t, http.StatusOK, rec.Code, "Expected status code 200") + assert.Contains(t, rec.Body.String(), "test_metric", "Expected response body to contain 'test_metric'") +} diff --git a/pkg/handlers/ping_test.go b/pkg/handlers/ping_test.go new file mode 100644 index 0000000..638196d --- /dev/null +++ b/pkg/handlers/ping_test.go @@ -0,0 +1,79 @@ +package handlers + +import ( + "heartbeats/pkg/config" + "heartbeats/pkg/heartbeat" + "heartbeats/pkg/history" + "heartbeats/pkg/logger" + "heartbeats/pkg/notify" + "heartbeats/pkg/timer" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPingHandler(t *testing.T) { + log := logger.NewLogger(true) + config.App.HeartbeatStore = heartbeat.NewStore() + config.App.NotificationStore = notify.NewStore() + config.HistoryStore = history.NewStore() + + h := &heartbeat.Heartbeat{ + Name: "test", + Enabled: new(bool), + Interval: &timer.Timer{Interval: new(time.Duration)}, + Grace: &timer.Timer{Interval: new(time.Duration)}, + } + *h.Enabled = true + *h.Interval.Interval = time.Minute + *h.Grace.Interval = time.Minute + + err := config.App.HeartbeatStore.Add("test", h) + assert.NoError(t, err) + + hist, err := history.NewHistory(10, 2) + assert.NoError(t, err) + + err = config.HistoryStore.Add("test", hist) + assert.NoError(t, err) + + ns := notify.NewStore() + config.App.NotificationStore = ns + + mux := http.NewServeMux() + mux.Handle("GET /ping/{id}", Ping(log)) + + t.Run("Heartbeat found and enabled", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "Expected status code 200") + assert.Equal(t, "ok", rec.Body.String(), "Expected response body 'ok'") + }) + + t.Run("Heartbeat not found", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping/nonexistent", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code, "Expected status code 404") + assert.Contains(t, rec.Body.String(), "Heartbeat 'nonexistent' not found", "Expected heartbeat not found message") + }) + + t.Run("Heartbeat found but not enabled", func(t *testing.T) { + *h.Enabled = false + req := httptest.NewRequest("GET", "/ping/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code, "Expected status code 503") + assert.Contains(t, rec.Body.String(), "Heartbeat 'test' not enabled", "Expected heartbeat not enabled message") + }) +} diff --git a/pkg/handlers/utils_test.go b/pkg/handlers/utils_test.go new file mode 100644 index 0000000..efbb1ff --- /dev/null +++ b/pkg/handlers/utils_test.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestIsFalse(t *testing.T) { + bTrue := true + bFalse := false + + assert.True(t, isFalse(&bFalse), "Expected isFalse to return true for false pointer") + assert.False(t, isFalse(&bTrue), "Expected isFalse to return false for true pointer") + assert.False(t, isFalse(nil), "Expected isFalse to return false for nil pointer") +} + +func TestIsTrue(t *testing.T) { + bTrue := true + bFalse := false + + assert.False(t, isTrue(&bFalse), "Expected isTrue to return false for false pointer") + assert.True(t, isTrue(&bTrue), "Expected isTrue to return true for true pointer") + assert.False(t, isTrue(nil), "Expected isTrue to return false for nil pointer") +} + +func TestFormatTime(t *testing.T) { + format := "2006-01-02 15:04:05" + + t.Run("Non-zero time", func(t *testing.T) { + tm := time.Date(2021, 9, 15, 14, 0, 0, 0, time.UTC) + expected := "2021-09-15 14:00:00" + assert.Equal(t, expected, formatTime(tm, format), "Expected formatted time to match") + }) + + t.Run("Zero time", func(t *testing.T) { + tm := time.Time{} + expected := "-" + assert.Equal(t, expected, formatTime(tm, format), "Expected formatted time to be '-' for zero time") + }) +} + +func TestGetClientIP(t *testing.T) { + t.Run("With X-Forwarded-For header", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-For", "192.168.1.1, 10.0.0.1") + expected := "192.168.1.1" + assert.Equal(t, expected, getClientIP(req), "Expected to return the first IP from X-Forwarded-For header") + }) + + t.Run("Without X-Forwarded-For header", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "127.0.0.1:8080" + expected := "127.0.0.1:8080" + assert.Equal(t, expected, getClientIP(req), "Expected to return RemoteAddr when X-Forwarded-For header is absent") + }) +} diff --git a/pkg/heartbeat/heartbeat_test.go b/pkg/heartbeat/heartbeat_test.go new file mode 100644 index 0000000..60fc2d3 --- /dev/null +++ b/pkg/heartbeat/heartbeat_test.go @@ -0,0 +1,189 @@ +package heartbeat + +import ( + "context" + "heartbeats/pkg/history" + "heartbeats/pkg/logger" + "heartbeats/pkg/notify" + "heartbeats/pkg/timer" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStore(t *testing.T) { + store := NewStore() + + interval := time.Second * 3 + grace := time.Second * 5 + tm := timer.Timer{Interval: &interval} + gr := timer.Timer{Interval: &grace} + + h := &Heartbeat{ + Name: "test", + Interval: &tm, + Grace: &gr, + } + + t.Run("Add", func(t *testing.T) { + err := store.Add("test", h) + assert.NoError(t, err, "Expected no error when adding a heartbeat") + }) + + t.Run("Add duplicate", func(t *testing.T) { + err := store.Add("test", h) + assert.Error(t, err, "Expected error when adding a duplicate heartbeat") + }) + + t.Run("Get All", func(t *testing.T) { + all := store.GetAll() + assert.Equal(t, 1, len(all), "Expected one heartbeat in store") + }) + + t.Run("Get", func(t *testing.T) { + retrieved := store.Get("test") + assert.NotNil(t, retrieved, "Expected to retrieve the added heartbeat") + }) + + t.Run("Delete", func(t *testing.T) { + store.Delete("test") + retrieved := store.Get("test") + assert.Nil(t, retrieved, "Expected heartbeat to be deleted") + }) +} + +func TestHeartbeatTimers(t *testing.T) { + log := logger.NewLogger(true) + hist, err := history.NewHistory(20, 20) + assert.NoError(t, err) + + ns := notify.NewStore() + + interval := time.Second * 2 + grace := time.Second * 2 + tm := timer.Timer{Interval: &interval} + gr := timer.Timer{Interval: &grace} + + h := &Heartbeat{ + Name: "test", + Interval: &tm, + Grace: &gr, + } + + ctx := context.Background() + + t.Run("StartInterval", func(t *testing.T) { + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + }) + + t.Run("Multiple StartTimer with sleep", func(t *testing.T) { + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + time.Sleep(1 * time.Second) + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + time.Sleep(1 * time.Second) + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + }) + + t.Run("Multiple StartTimer without sleep", func(t *testing.T) { + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + h.StartInterval(ctx, log, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be OK after starting interval") + }) + + t.Run("GraceAfterInterval", func(t *testing.T) { + h.StartInterval(ctx, log, ns, hist) + time.Sleep(3 * time.Second) // wait for the interval to elapse + assert.Equal(t, StatusGrace.String(), h.Status, "Expected status to be GRACE after interval elapsed") + }) + + t.Run("StartGrace", func(t *testing.T) { + h.StartGrace(ctx, log, ns, hist) + assert.Equal(t, StatusGrace.String(), h.Status, "Expected status to be GRACE after starting grace") + }) + + t.Run("GraceToNOK", func(t *testing.T) { + h.StartGrace(ctx, log, ns, hist) + time.Sleep(5 * time.Second) // wait for the grace to elapse + assert.Equal(t, StatusNOK.String(), h.Status, "Expected status to be NOK after grace elapsed") + }) + + t.Run("StopTimer", func(t *testing.T) { + h.StopTimers() + assert.Nil(t, h.Interval.Timer, "Expected interval timer to be stopped") + assert.Nil(t, h.Grace.Timer, "Expected grace timer to be stopped") + }) +} + +func TestHeartbeatUpdateStatus(t *testing.T) { + log := logger.NewLogger(true) + hist, err := history.NewHistory(10, 2) + assert.NoError(t, err) + + ns := notify.NewStore() + + interval := time.Second * 2 + grace := time.Second * 2 + tm := timer.Timer{Interval: &interval} + gr := timer.Timer{Interval: &grace} + + h := &Heartbeat{ + Name: "test", + Interval: &tm, + Grace: &gr, + } + + ctx := context.Background() + + // Test UpdateStatus + t.Run("UpdateStatus", func(t *testing.T) { + h.updateStatus(ctx, log, StatusOK, ns, hist) + assert.Equal(t, StatusOK.String(), h.Status, "Expected status to be updated to OK") + assert.False(t, h.LastPing.IsZero(), "Expected LastPing to be updated") + }) +} + +func TestSendNotifications(t *testing.T) { + log := logger.NewLogger(true) + hist, err := history.NewHistory(10, 2) + assert.NoError(t, err) + + ns := notify.NewStore() + + interval := time.Second * 2 + grace := time.Second * 2 + tm := timer.Timer{Interval: &interval} + gr := timer.Timer{Interval: &grace} + + h := &Heartbeat{ + Name: "test", + Interval: &tm, + Grace: &gr, + Notifications: []string{"test-notification"}, + } + + ctx := context.Background() + + notification := ¬ify.Notification{ + Name: "test-notification", + Enabled: boolPtr(true), + } + + _ = ns.Add("test-notification", notification) + + t.Run("SendNotifications", func(t *testing.T) { + h.SendNotifications(ctx, log, ns, hist, false) + // assert.NotEmpty(t, hist.GetAllEntries(), "Expected notifications to be sent and logged in history") + }) +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/pkg/history/enums.go b/pkg/history/enums.go index b142a61..4205375 100644 --- a/pkg/history/enums.go +++ b/pkg/history/enums.go @@ -1,6 +1,5 @@ package history -// Event type represents various events that can be logged in history. type Event int16 const ( diff --git a/pkg/history/history_test.go b/pkg/history/history_test.go new file mode 100644 index 0000000..b5fe7ab --- /dev/null +++ b/pkg/history/history_test.go @@ -0,0 +1,169 @@ +package history + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestHistory_NewHistory(t *testing.T) { + _, err := NewHistory(5, 120) + assert.Error(t, err) +} + +func TestHistory_AddEntry(t *testing.T) { + t.Run("Add Entries", func(t *testing.T) { + h, err := NewHistory(5, 20) + assert.NoError(t, err) + + h.AddEntry(Beat, "Beat message", nil) + assert.Equal(t, 1, len(h.GetAllEntries())) + + h.AddEntry(Interval, "Interval message", nil) + h.AddEntry(Grace, "Grace message", nil) + h.AddEntry(Expired, "Expired message", nil) + h.AddEntry(Send, "Send message", nil) + assert.Equal(t, 5, len(h.GetAllEntries())) + + h.AddEntry(Beat, "New Beat message", nil) + assert.Equal(t, 4, len(h.GetAllEntries())) // Reduced by 20% (5 * 0.8 = 4) + }) + + t.Run("Add Entries, reduce odd", func(t *testing.T) { + h, err := NewHistory(5, 25) + assert.NoError(t, err) + + h.AddEntry(Beat, "Beat message", nil) + assert.Equal(t, 1, len(h.GetAllEntries())) + + h.AddEntry(Interval, "Interval message", nil) + h.AddEntry(Grace, "Grace message", nil) + h.AddEntry(Expired, "Expired message", nil) + h.AddEntry(Send, "Send message", nil) + assert.Equal(t, 5, len(h.GetAllEntries())) + + h.AddEntry(Beat, "New Beat message", nil) + assert.Equal(t, 4, len(h.GetAllEntries())) // Reduced by 25% (5 * 0.75 ~= 4) + }) +} + +func TestHistory_GetAllEntries(t *testing.T) { + h, err := NewHistory(5, 20) + assert.NoError(t, err) + + t.Run("AddEntries", func(t *testing.T) { + h.AddEntry(Beat, "Beat message", nil) + h.AddEntry(Interval, "Interval message", nil) + }) + + t.Run("VerifyEntries", func(t *testing.T) { + entries := h.GetAllEntries() + assert.Equal(t, 2, len(entries), "Expected 2 entries") + assert.Equal(t, Beat, entries[0].Event, "Expected first entry to be a Beat event") + assert.Equal(t, "Beat message", entries[0].Message, "Expected first entry message to be 'Beat message'") + assert.Equal(t, Interval, entries[1].Event, "Expected second entry to be an Interval event") + assert.Equal(t, "Interval message", entries[1].Message, "Expected second entry message to be 'Interval message'") + }) +} + +func TestStore_Add_Get_Delete(t *testing.T) { + store := NewStore() + h, err := NewHistory(5, 20) + assert.NoError(t, err) + + t.Run("Add", func(t *testing.T) { + err = store.Add("test", h) + assert.NoError(t, err, "Expected no error when adding a history") + }) + + t.Run("Get", func(t *testing.T) { + retrieved := store.Get("test") + assert.NotNil(t, retrieved, "Expected to retrieve the added history") + }) + + t.Run("Delete", func(t *testing.T) { + store.Delete("test") + retrieved := store.Get("test") + assert.Nil(t, retrieved, "Expected history to be deleted") + }) +} + +func TestStore_AddDuplicate(t *testing.T) { + store := NewStore() + h, err := NewHistory(5, 20) + assert.NoError(t, err) + + t.Run("Add", func(t *testing.T) { + err = store.Add("test", h) + assert.NoError(t, err) + }) + + t.Run("Duplicate", func(t *testing.T) { + err = store.Add("test", h) + assert.Error(t, err) + }) +} + +func TestEvent_String(t *testing.T) { + t.Run("TestBeat", func(t *testing.T) { + assert.Equal(t, "BEAT", Beat.String()) + }) + + t.Run("TestInterval", func(t *testing.T) { + assert.Equal(t, "INTERVAL", Interval.String()) + }) + + t.Run("TestGrace", func(t *testing.T) { + assert.Equal(t, "GRACE", Grace.String()) + }) + + t.Run("TestExpired", func(t *testing.T) { + assert.Equal(t, "EXPIRED", Expired.String()) + }) + + t.Run("TestSend", func(t *testing.T) { + assert.Equal(t, "SEND", Send.String()) + }) +} + +func TestHistoryEntry(t *testing.T) { + entry := HistoryEntry{ + Time: time.Now(), + Event: Beat, + Message: "Test message", + Details: map[string]string{"key": "value"}, + } + + t.Run("TestEvent", func(t *testing.T) { + assert.Equal(t, Beat, entry.Event) + }) + + t.Run("TestMessage", func(t *testing.T) { + assert.Equal(t, "Test message", entry.Message) + }) + + t.Run("TestDetails", func(t *testing.T) { + assert.Equal(t, "value", entry.Details["key"]) + }) +} + +func TestMarshalYAML(t *testing.T) { + store := NewStore() + h, err := NewHistory(5, 20) + + t.Run("TestNewHistory", func(t *testing.T) { + assert.NoError(t, err) + }) + + t.Run("TestAddHistory", func(t *testing.T) { + err = store.Add("test", h) + assert.NoError(t, err) + }) + + t.Run("TestMarshalStore", func(t *testing.T) { + data, err := store.MarshalYAML() + assert.NoError(t, err) + assert.NotNil(t, data) + }) +} diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go new file mode 100644 index 0000000..bb0b648 --- /dev/null +++ b/pkg/metrics/metrics_test.go @@ -0,0 +1,69 @@ +package metrics + +import ( + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +func TestNewMetrics(t *testing.T) { + PromMetrics = *NewMetrics() + assert.NotNil(t, PromMetrics.Registry, "Registry should not be nil") + + // Initialize metrics + HeartbeatStatus.With(prometheus.Labels{"heartbeat": "test_heartbeat"}) + TotalHeartbeats.With(prometheus.Labels{"heartbeat": "test_heartbeat"}) + + gatherers := prometheus.Gatherers{ + PromMetrics.Registry, + } + + mfs, err := gatherers.Gather() + assert.NoError(t, err, "Expected no error while gathering metrics") + + foundHeartbeatStatus := false + foundTotalHeartbeats := false + + for _, mf := range mfs { + if *mf.Name == "heartbeats_heartbeat_last_status" { + foundHeartbeatStatus = true + } + if *mf.Name == "heartbeats_heartbeats_total" { + foundTotalHeartbeats = true + } + } + + assert.True(t, foundHeartbeatStatus, "Expected to find heartbeats_heartbeat_last_status metric") + assert.True(t, foundTotalHeartbeats, "Expected to find heartbeats_heartbeats_total metric") +} + +func TestHeartbeatStatusMetric(t *testing.T) { + PromMetrics = *NewMetrics() + + HeartbeatStatus.With(prometheus.Labels{"heartbeat": "test_heartbeat"}).Set(UP) + + expected := ` + # HELP heartbeats_heartbeat_last_status Total number of heartbeats + # TYPE heartbeats_heartbeat_last_status gauge + heartbeats_heartbeat_last_status{heartbeat="test_heartbeat"} 1 + ` + err := testutil.GatherAndCompare(PromMetrics.Registry, strings.NewReader(expected), "heartbeats_heartbeat_last_status") + assert.NoError(t, err, "Expected no error while gathering and comparing metrics") +} + +func TestTotalHeartbeatsMetric(t *testing.T) { + PromMetrics = *NewMetrics() + + TotalHeartbeats.With(prometheus.Labels{"heartbeat": "test_heartbeat"}).Inc() + + expected := ` + # HELP heartbeats_heartbeats_total The total number of heartbeats + # TYPE heartbeats_heartbeats_total counter + heartbeats_heartbeats_total{heartbeat="test_heartbeat"} 1 + ` + err := testutil.GatherAndCompare(PromMetrics.Registry, strings.NewReader(expected), "heartbeats_heartbeats_total") + assert.NoError(t, err, "Expected no error while gathering and comparing metrics") +} diff --git a/pkg/notify/notifier/email.go b/pkg/notify/notifier/email.go index 48f1be7..d266aa6 100644 --- a/pkg/notify/notifier/email.go +++ b/pkg/notify/notifier/email.go @@ -3,7 +3,7 @@ package notifier import ( "context" "fmt" - "heartbeats/pkg/notify/resolve" + "heartbeats/pkg/notify/resolver" "heartbeats/pkg/notify/services/email" "time" ) @@ -88,19 +88,19 @@ func (e EmailNotifier) CheckResolveVariables() error { // - error: An error if resolving any field fails. func resolveSMTPConfig(config email.SMTPConfig) (email.SMTPConfig, error) { var err error - config.Host, err = resolve.ResolveVariable(config.Host) + config.Host, err = resolver.ResolveVariable(config.Host) if err != nil { return email.SMTPConfig{}, err } - config.From, err = resolve.ResolveVariable(config.From) + config.From, err = resolver.ResolveVariable(config.From) if err != nil { return email.SMTPConfig{}, err } - config.Username, err = resolve.ResolveVariable(config.Username) + config.Username, err = resolver.ResolveVariable(config.Username) if err != nil { return email.SMTPConfig{}, err } - config.Password, err = resolve.ResolveVariable(config.Password) + config.Password, err = resolver.ResolveVariable(config.Password) if err != nil { return email.SMTPConfig{}, err } @@ -119,32 +119,32 @@ func resolveEmailConfig(config email.Email) (email.Email, error) { var err error for i, to := range config.To { - config.To[i], err = resolve.ResolveVariable(to) + config.To[i], err = resolver.ResolveVariable(to) if err != nil { return email.Email{}, err } } for i, cc := range config.Cc { - config.Cc[i], err = resolve.ResolveVariable(cc) + config.Cc[i], err = resolver.ResolveVariable(cc) if err != nil { return email.Email{}, err } } for i, bcc := range config.Bcc { - config.Bcc[i], err = resolve.ResolveVariable(bcc) + config.Bcc[i], err = resolver.ResolveVariable(bcc) if err != nil { return email.Email{}, err } } - config.Subject, err = resolve.ResolveVariable(config.Subject) + config.Subject, err = resolver.ResolveVariable(config.Subject) if err != nil { return email.Email{}, err } - config.Body, err = resolve.ResolveVariable(config.Body) + config.Body, err = resolver.ResolveVariable(config.Body) if err != nil { return email.Email{}, err } diff --git a/pkg/notify/notifier/msteams.go b/pkg/notify/notifier/msteams.go index 2630203..b695aab 100644 --- a/pkg/notify/notifier/msteams.go +++ b/pkg/notify/notifier/msteams.go @@ -3,7 +3,7 @@ package notifier import ( "context" "fmt" - "heartbeats/pkg/notify/resolve" + "heartbeats/pkg/notify/resolver" "heartbeats/pkg/notify/services/msteams" "time" ) @@ -86,17 +86,17 @@ func (e MSTeamsNotifier) CheckResolveVariables() error { // - MSTeamsConfig: The resolved MS Teams configuration. // - error: An error if resolving any field fails. func resolveMSTeamsConfig(config MSTeamsConfig) (MSTeamsConfig, error) { - webhookURL, err := resolve.ResolveVariable(config.WebhookURL) + webhookURL, err := resolver.ResolveVariable(config.WebhookURL) if err != nil { return MSTeamsConfig{}, fmt.Errorf("cannot resolve webhook URL. %w", err) } - title, err := resolve.ResolveVariable(config.Title) + title, err := resolver.ResolveVariable(config.Title) if err != nil { return MSTeamsConfig{}, fmt.Errorf("cannot resolve MS Teams title. %w", err) } - text, err := resolve.ResolveVariable(config.Text) + text, err := resolver.ResolveVariable(config.Text) if err != nil { return MSTeamsConfig{}, fmt.Errorf("cannot resolve MS Teams text. %w", err) } diff --git a/pkg/notify/notifier/slack.go b/pkg/notify/notifier/slack.go index d817746..ac83caf 100644 --- a/pkg/notify/notifier/slack.go +++ b/pkg/notify/notifier/slack.go @@ -3,7 +3,7 @@ package notifier import ( "context" "fmt" - "heartbeats/pkg/notify/resolve" + "heartbeats/pkg/notify/resolver" "heartbeats/pkg/notify/services/slack" "heartbeats/pkg/notify/utils" "time" @@ -105,27 +105,27 @@ func (e SlackNotifier) CheckResolveVariables() error { // - SlackConfig: The resolved SlackConfig. // - error: An error if any of the configuration values cannot be resolved. func resolveSlackConfig(config SlackConfig) (SlackConfig, error) { - token, err := resolve.ResolveVariable(config.Token) + token, err := resolver.ResolveVariable(config.Token) if err != nil { return SlackConfig{}, fmt.Errorf("cannot resolve Slack token. %w", err) } - channel, err := resolve.ResolveVariable(config.Channel) + channel, err := resolver.ResolveVariable(config.Channel) if err != nil { return SlackConfig{}, fmt.Errorf("cannot resolve Slack channel. %w", err) } - title, err := resolve.ResolveVariable(config.Title) + title, err := resolver.ResolveVariable(config.Title) if err != nil { return SlackConfig{}, fmt.Errorf("cannot resolve Slack title. %w", err) } - text, err := resolve.ResolveVariable(config.Text) + text, err := resolver.ResolveVariable(config.Text) if err != nil { return SlackConfig{}, fmt.Errorf("cannot resolve Slack text. %w", err) } - colorTemplate, err := resolve.ResolveVariable(config.ColorTemplate) + colorTemplate, err := resolver.ResolveVariable(config.ColorTemplate) if err != nil { return SlackConfig{}, fmt.Errorf("cannot resolve Slack color template. %w", err) } diff --git a/pkg/notify/resolve/resolve.go b/pkg/notify/resolver/resolver.go similarity index 99% rename from pkg/notify/resolve/resolve.go rename to pkg/notify/resolver/resolver.go index d972f6e..f0128ee 100644 --- a/pkg/notify/resolve/resolve.go +++ b/pkg/notify/resolver/resolver.go @@ -1,4 +1,4 @@ -package resolve +package resolver import ( "bufio" diff --git a/pkg/notify/resolver/resolver_test.go b/pkg/notify/resolver/resolver_test.go new file mode 100644 index 0000000..46bd42e --- /dev/null +++ b/pkg/notify/resolver/resolver_test.go @@ -0,0 +1,78 @@ +package resolver + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResolveVariable(t *testing.T) { + t.Run("Resolve environment variable", func(t *testing.T) { + os.Setenv("TEST_ENV_VAR", "test_value") + defer os.Unsetenv("TEST_ENV_VAR") + + result, err := ResolveVariable("env:TEST_ENV_VAR") + assert.NoError(t, err) + assert.Equal(t, "test_value", result) + }) + + t.Run("Resolve non-existing environment variable", func(t *testing.T) { + result, err := ResolveVariable("env:NON_EXISTING_ENV_VAR") + assert.Error(t, err) + assert.Equal(t, "", result) + assert.Contains(t, err.Error(), "environment variable 'NON_EXISTING_ENV_VAR' not found") + }) + + t.Run("Resolve file variable", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name()) + assert.NoError(t, err) + assert.Equal(t, fileContent, result+"\n") + }) + + t.Run("Resolve file with key", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name() + "//key1") + assert.NoError(t, err) + assert.Equal(t, "value1", result) + }) + + t.Run("Resolve file with non-existing key", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name() + "//non_existing_key") + assert.Error(t, err) + assert.Equal(t, "", result) + assert.Contains(t, err.Error(), "Key 'non_existing_key' not found in file") + }) + + t.Run("Resolve plain string", func(t *testing.T) { + result, err := ResolveVariable("plain_string") + assert.NoError(t, err) + assert.Equal(t, "plain_string", result) + }) +} diff --git a/pkg/notify/services/slack/slack.go b/pkg/notify/services/slack/slack.go index b712dce..226bbf4 100644 --- a/pkg/notify/services/slack/slack.go +++ b/pkg/notify/services/slack/slack.go @@ -8,6 +8,8 @@ import ( "net/http" ) +const SLACK_WEBHOOK_URL = "https://slack.com/api/chat.postMessage" + // Slack represents the structure of a Slack message. type Slack struct { Channel string `json:"channel"` @@ -56,14 +58,12 @@ func New(headers map[string]string, skipTLS bool) *Client { // - *Response: The response from the Slack API. // - error: An error if sending the message fails. func (c *Client) Send(ctx context.Context, slackMessage Slack) (*Response, error) { - webhookURL := "https://slack.com/api/chat.postMessage" - data, err := json.Marshal(slackMessage) if err != nil { return nil, fmt.Errorf("error marshalling Slack message. %w", err) } - resp, err := c.HttpClient.DoRequest(ctx, "POST", webhookURL, data) + resp, err := c.HttpClient.DoRequest(ctx, "POST", SLACK_WEBHOOK_URL, data) if err != nil { return nil, fmt.Errorf("error sending HTTP request. %w", err) } diff --git a/pkg/notify/utils/request_test.go b/pkg/notify/utils/request_test.go new file mode 100644 index 0000000..37d09f1 --- /dev/null +++ b/pkg/notify/utils/request_test.go @@ -0,0 +1,74 @@ +package utils + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpClient(t *testing.T) { + t.Run("Create new HttpClient", func(t *testing.T) { + headers := map[string]string{ + "Content-Type": "application/json", + } + client := NewHttpClient(headers, true) + + assert.NotNil(t, client) + assert.Equal(t, headers, client.Headers) + assert.True(t, client.SkipInsecure) + }) +} + +func TestHttpClient_DoRequest(t *testing.T) { + t.Run("Perform successful GET request", func(t *testing.T) { + headers := map[string]string{ + "Content-Type": "application/json", + } + client := NewHttpClient(headers, true) + + // Setup a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) + })) + defer server.Close() + + ctx := context.Background() + resp, err := client.DoRequest(ctx, http.MethodGet, server.URL, nil) + + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Perform GET request with invalid URL", func(t *testing.T) { + headers := map[string]string{ + "Content-Type": "application/json", + } + client := NewHttpClient(headers, true) + + ctx := context.Background() + _, err := client.DoRequest(ctx, http.MethodGet, "http://invalid-url", nil) + + assert.Error(t, err) + }) +} + +func TestHttpClient_createHTTPClient(t *testing.T) { + t.Run("Create HTTP client with custom transport", func(t *testing.T) { + client := NewHttpClient(nil, true) + httpClient := client.createHTTPClient() + + assert.NotNil(t, httpClient) + assert.IsType(t, &http.Client{}, httpClient) + assert.IsType(t, &http.Transport{}, httpClient.Transport) + + transport := httpClient.Transport.(*http.Transport) + assert.NotNil(t, transport.TLSClientConfig) + assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + }) +} diff --git a/pkg/notify/utils/template.go b/pkg/notify/utils/template.go index b33ab29..9456a50 100644 --- a/pkg/notify/utils/template.go +++ b/pkg/notify/utils/template.go @@ -31,7 +31,7 @@ func FormatTemplate(name, tmpl string, intr interface{}) (string, error) { } buf := &bytes.Buffer{} - if err := t.Execute(buf, &intr); err != nil { + if err := t.Execute(buf, intr); err != nil { return "", fmt.Errorf("error executing template. %w", err) } diff --git a/pkg/notify/utils/template_test.go b/pkg/notify/utils/template_test.go new file mode 100644 index 0000000..ee4814a --- /dev/null +++ b/pkg/notify/utils/template_test.go @@ -0,0 +1,52 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatTemplate(t *testing.T) { + t.Run("Empty template", func(t *testing.T) { + result, err := FormatTemplate("empty", "", nil) + assert.Error(t, err) + assert.Equal(t, "", result) + assert.Equal(t, "template is empty", err.Error()) + }) + + t.Run("Simple template", func(t *testing.T) { + tmpl := "Hello, {{ .Name }}!" + data := map[string]string{"Name": "World"} + + result, err := FormatTemplate("simple", tmpl, data) + assert.NoError(t, err) + assert.Equal(t, "Hello, World!", result) + }) + + t.Run("Template with sprig function", func(t *testing.T) { + tmpl := "The date is {{ now | date \"2006-01-02\" }}." + data := map[string]string{} + + result, err := FormatTemplate("sprig", tmpl, data) + assert.NoError(t, err) + assert.Contains(t, result, "The date is ") + }) + + t.Run("Template with missing field", func(t *testing.T) { + tmpl := "Hello, {{ .Name | default \"\" }}!" + data := struct{ Name *string }{nil} + + result, err := FormatTemplate("missing", tmpl, data) + assert.NoError(t, err) + assert.Equal(t, "Hello, !", result) + }) + + t.Run("Complex template with multiple fields", func(t *testing.T) { + tmpl := "Hello, {{ .Name }}! Today is {{ .Day }}." + data := map[string]string{"Name": "Alice", "Day": "Monday"} + + result, err := FormatTemplate("complex", tmpl, data) + assert.NoError(t, err) + assert.Equal(t, "Hello, Alice! Today is Monday.", result) + }) +} diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 269897c..b1e59f4 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -1,7 +1,6 @@ package server import ( - "embed" "heartbeats/pkg/handlers" "heartbeats/pkg/logger" "io/fs" @@ -9,12 +8,11 @@ import ( ) // newRouter creates a new Server mux and appends Handlers -func newRouter(logger logger.Logger, staticFS embed.FS) http.Handler { +func newRouter(logger logger.Logger, staticFS fs.FS) http.Handler { mux := http.NewServeMux() // Handler for embedded static files - filesystem := fs.FS(staticFS) - staticContent, _ := fs.Sub(filesystem, "web/static") + staticContent, _ := fs.Sub(staticFS, "web/static") fileServer := http.FileServer(http.FS(staticContent)) mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer)) diff --git a/pkg/server/routes_test.go b/pkg/server/routes_test.go new file mode 100644 index 0000000..4ec636e --- /dev/null +++ b/pkg/server/routes_test.go @@ -0,0 +1,178 @@ +package server + +import ( + "heartbeats/pkg/config" + "heartbeats/pkg/heartbeat" + "heartbeats/pkg/history" + "heartbeats/pkg/logger" + "heartbeats/pkg/notify" + "heartbeats/pkg/timer" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" +) + +// customAferoFS implements the fs.FS interface for afero.Fs +type customAferoFS struct { + fs afero.Fs +} + +// Open implements the fs.FS interface +func (a *customAferoFS) Open(name string) (fs.File, error) { + return a.fs.Open(name) +} + +// Convert the afero.Fs to customAferoFS +func aferoToCustomAferoFS(afs afero.Fs) fs.FS { + return &customAferoFS{fs: afs} +} + +func setupAferoFSForRoutes() afero.Fs { + aferoFS := afero.NewMemMapFs() + + staticFiles := []string{ + "web/static/css/heartbeats.css", + "web/templates/history.html", + "web/templates/heartbeats.html", + "web/templates/footer.html", + } + + for _, file := range staticFiles { + content, err := os.ReadFile(filepath.Join("../../", file)) + if err != nil { + panic(err) + } + + err = afero.WriteFile(aferoFS, file, content, 0644) + if err != nil { + panic(err) + } + } + + return aferoFS +} + +func TestNewRouter(t *testing.T) { + log := logger.NewLogger(true) + config.App.HeartbeatStore = heartbeat.NewStore() + config.App.NotificationStore = notify.NewStore() + config.HistoryStore = history.NewStore() + + h := &heartbeat.Heartbeat{ + Name: "test", + Enabled: new(bool), + Interval: &timer.Timer{Interval: new(time.Duration)}, + Grace: &timer.Timer{Interval: new(time.Duration)}, + } + *h.Enabled = true + *h.Interval.Interval = time.Minute + *h.Grace.Interval = time.Minute + + err := config.App.HeartbeatStore.Add("test", h) + assert.NoError(t, err) + + hist, err := history.NewHistory(10, 2) + assert.NoError(t, err) + + err = config.HistoryStore.Add("test", hist) + assert.NoError(t, err) + + aferoFS := setupAferoFSForRoutes() + customFS := aferoToCustomAferoFS(aferoFS) + + mux := newRouter(log, customFS) + + t.Run("GET /", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "Heartbeat") + }) + + t.Run("GET /ping/test", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + }) + + t.Run("POST /ping/test", func(t *testing.T) { + req := httptest.NewRequest("POST", "/ping/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + }) + + t.Run("GET /history/test", func(t *testing.T) { + req := httptest.NewRequest("GET", "/history/test", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("GET /healthz", func(t *testing.T) { + req := httptest.NewRequest("GET", "/healthz", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + }) + + t.Run("POST /healthz", func(t *testing.T) { + req := httptest.NewRequest("POST", "/healthz", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "ok", rec.Body.String()) + }) + + t.Run("GET /metrics", func(t *testing.T) { + req := httptest.NewRequest("GET", "/metrics", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("GET /static/example.txt", func(t *testing.T) { + req := httptest.NewRequest("GET", "/static/css/heartbeats.css", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("Heartbeat not found", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping/nonexistent", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Contains(t, rec.Body.String(), "Heartbeat 'nonexistent' not found") + }) +} diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go new file mode 100644 index 0000000..8dffe85 --- /dev/null +++ b/pkg/server/server_test.go @@ -0,0 +1,26 @@ +package server + +import ( + "context" + "embed" + "heartbeats/pkg/logger" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRun(t *testing.T) { + log := logger.NewLogger(true) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + time.Sleep(2 * time.Second) + cancel() + }() + + var staticFS embed.FS + err := Run(ctx, "localhost:8080", staticFS, log) + assert.NoError(t, err) +} diff --git a/pkg/timer/timer_test.go b/pkg/timer/timer_test.go new file mode 100644 index 0000000..552ae02 --- /dev/null +++ b/pkg/timer/timer_test.go @@ -0,0 +1,76 @@ +package timer + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimer(t *testing.T) { + t.Run("UnmarshalYAML", func(t *testing.T) { + var tm Timer + durationStr := "2s" + err := tm.UnmarshalYAML(func(v interface{}) error { + *v.(*string) = durationStr + return nil + }) + assert.NoError(t, err) + assert.Equal(t, 2*time.Second, *tm.Interval) + }) + + t.Run("RunTimer", func(t *testing.T) { + tm := Timer{ + Interval: durationPtr(1 * time.Second), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var called bool + tm.RunTimer(ctx, func() { + called = true + }) + + time.Sleep(2 * time.Second) + assert.True(t, called) + }) + + t.Run("RunTimerWithCancel", func(t *testing.T) { + tm := Timer{ + Interval: durationPtr(1 * time.Second), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var called bool + tm.RunTimer(ctx, func() { + called = true + }) + + cancel() + time.Sleep(2 * time.Second) + assert.False(t, called) + }) + + t.Run("StopTimer", func(t *testing.T) { + tm := Timer{ + Interval: durationPtr(1 * time.Second), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var called bool + tm.RunTimer(ctx, func() { + called = true + }) + + tm.StopTimer() + time.Sleep(2 * time.Second) + assert.False(t, called) + }) +} + +func durationPtr(d time.Duration) *time.Duration { + return &d +} diff --git a/web/templates/heartbeat.html b/web/templates/heartbeats.html similarity index 100% rename from web/templates/heartbeat.html rename to web/templates/heartbeats.html