diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index 0641dd1a7d4..5489faa37c8 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -275,7 +275,8 @@ func (cli cliNotifications) newTestCmd() *cobra.Command { Args: cobra.ExactArgs(1), DisableAutoGenTag: true, ValidArgsFunction: cli.notificationConfigFilter, - PreRunE: func(_ *cobra.Command, args []string) error { + PreRunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() cfg := cli.cfg() pconfigs, err := cli.getPluginConfigs() if err != nil { @@ -286,7 +287,7 @@ func (cli cliNotifications) newTestCmd() *cobra.Command { return fmt.Errorf("plugin name: '%s' does not exist", args[0]) } // Create a single profile with plugin name as notification name - return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{ + return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{ { Notifications: []string{ pcfg.Name, @@ -377,12 +378,13 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return nil }, - RunE: func(_ *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb ) + ctx := cmd.Context() cfg := cli.cfg() if alertOverride != "" { @@ -391,7 +393,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not } } - err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) + err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) if err != nil { return fmt.Errorf("can't initialize plugins: %w", err) } diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 6ab41def16f..ccb0acf0209 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -40,7 +40,7 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined") } - err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) + err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) if err != nil { return nil, fmt.Errorf("unable to run plugin broker: %w", err) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index cd981f76542..4cc215c344f 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -72,8 +72,8 @@ func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.Wat } func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { - body := CreateTestMachine(t, router, "") - ValidateMachine(t, "test", config.API.Server.DbConfig) + body := CreateTestMachine(t, ctx, router, "") + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index b04ad687e4e..cdf99462c35 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -180,9 +180,7 @@ func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csc return router, config } -func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { - ctx := context.TODO() - +func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) { dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) @@ -269,7 +267,7 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map return response, resp.Code } -func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { +func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string { regReq := MachineTest regReq.RegistrationToken = token b, err := json.Marshal(regReq) @@ -277,8 +275,6 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { body := string(b) - ctx := context.Background() - w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 710cf82ad00..f6f51763975 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -14,7 +14,7 @@ func TestLogin(t *testing.T) { ctx := context.Background() router, config := NewAPITest(t, ctx) - body := CreateTestMachine(t, router, "") + body := CreateTestMachine(t, ctx, router, "") // Login with machine not validated yet w := httptest.NewRecorder() @@ -53,7 +53,7 @@ func TestLogin(t *testing.T) { assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) // Validate machine - ValidateMachine(t, "test", config.API.Server.DbConfig) + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index e60cec30e54..969f75707d6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -131,7 +131,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) { ctx := context.Background() router, _ := NewAPITest(t, ctx) - body := CreateTestMachine(t, router, "") + body := CreateTestMachine(t, ctx, router, "") w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index 31d7ac82fb2..e996fa9b68c 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -72,7 +72,7 @@ type ProfileAlert struct { Alert *models.Alert } -func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { +func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { pb.PluginChannel = make(chan ProfileAlert) pb.notificationConfigsByPluginType = make(map[string][][]byte) pb.notificationPluginByName = make(map[string]protobufs.NotifierServer) @@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs if err := pb.loadConfig(configPaths.NotificationDir); err != nil { return fmt.Errorf("while loading plugin config: %w", err) } - if err := pb.loadPlugins(configPaths.PluginDir); err != nil { + if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil { return fmt.Errorf("while loading plugin: %w", err) } pb.watcher = PluginWatcher{} @@ -230,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error { return nil } -func (pb *PluginBroker) loadPlugins(path string) error { +func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error { binaryPaths, err := listFilesAtPath(path) if err != nil { return err @@ -265,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return err } data = []byte(csstring.StrictExpand(string(data), os.LookupEnv)) - _, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data}) + _, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data}) if err != nil { return fmt.Errorf("while configuring %s: %w", pc.Name, err) } diff --git a/pkg/csplugin/broker_suite_test.go b/pkg/csplugin/broker_suite_test.go index 778bb2dfe2e..1210c67058a 100644 --- a/pkg/csplugin/broker_suite_test.go +++ b/pkg/csplugin/broker_suite_test.go @@ -1,6 +1,7 @@ package csplugin import ( + "context" "io" "os" "os/exec" @@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() { func (s *PluginSuite) SetupSubTest() { var err error + t := s.T() s.runDir, err = os.MkdirTemp("", "cs_plugin_test") @@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() { func (s *PluginSuite) TearDownSubTest() { t := s.T() + if s.pluginBroker != nil { s.pluginBroker.Kill() s.pluginBroker = nil @@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() { os.Remove("./out") } -func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) { +func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) { pb := PluginBroker{} + if procCfg == nil { procCfg = &csconfig.PluginCfg{} } + profiles := csconfig.NewDefaultConfig().API.Server.Profiles profiles = append(profiles, &csconfig.ProfileCfg{ Notifications: []string{"dummy_default"}, }) - err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{ + + err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{ PluginDir: s.pluginDir, NotificationDir: s.notifDir, }) + s.pluginBroker = &pb + return s.pluginBroker, err } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index 48f5a71f773..ae5a615b489 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -53,6 +54,7 @@ func (s *PluginSuite) writeconfig(config PluginConfig) { } func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -135,20 +137,22 @@ func (s *PluginSuite) TestBrokerInit() { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerNoThreshold() { + ctx := context.Background() + var alerts []models.Alert DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { + ctx := context.Background() + // test grouping by "time" DefaultEmptyTicker = 50 * time.Millisecond @@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { cfg.GroupWait = 4 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { } func (s *PluginSuite) TestBrokerRunGroupThreshold() { + ctx := context.Background() // test grouping by "size" DefaultEmptyTicker = 50 * time.Millisecond @@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { cfg.GroupThreshold = 4 s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { } func (s *PluginSuite) TestBrokerRunTimeThreshold() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { } func (s *PluginSuite) TestBrokerRunSimple() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index b7956bdcc0a..570f23e5015 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions */ func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -59,16 +61,17 @@ func (s *PluginSuite) TestBrokerInit() { if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerRun() { + ctx := context.Background() t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index b76c3c4eadd..84e63ec6493 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -15,11 +15,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -var ctx = context.Background() - func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) { testTomb.Kill(nil) <-pw.PluginEvents + if err := testTomb.Wait(); err != nil { log.Fatal(err) } @@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error { case <-ctx.Done(): return ctx.Err() } + return nil } func TestPluginWatcherInterval(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) testTomb := tomb.Tomb{} @@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") resetTestTomb(&testTomb, &pw) @@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) @@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) { } func TestPluginAlertCountWatcher(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) configs := map[string]PluginConfig{ @@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) { }, } testTomb := tomb.Tomb{} + pw.Init(configs, alertsByPluginName) pw.Start(&testTomb) // Channel won't contain any events since threshold is not crossed. ct, cancel := context.WithTimeout(ctx, time.Second) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel won't contain any events since threshold is not crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 4, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel will contain an event since threshold is crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 5, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw)