Skip to content

Commit

Permalink
context propagation: pkg/csplugin (#3273)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Oct 10, 2024
1 parent 50d115b commit 8ff58ee
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 33 deletions.
10 changes: 6 additions & 4 deletions cmd/crowdsec-cli/clinotifications/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
ValidArgsFunction: cli.notificationConfigFilter,
PreRunE: func(_ *cobra.Command, args []string) error {
PreRunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
cfg := cli.cfg()
pconfigs, err := cli.getPluginConfigs()
if err != nil {
Expand All @@ -286,7 +287,7 @@ func (cli cliNotifications) newTestCmd() *cobra.Command {
return fmt.Errorf("plugin name: '%s' does not exist", args[0])
}
// Create a single profile with plugin name as notification name
return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{
return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{
{
Notifications: []string{
pcfg.Name,
Expand Down Expand Up @@ -377,12 +378,13 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not

return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
)

ctx := cmd.Context()
cfg := cli.cfg()

if alertOverride != "" {
Expand All @@ -391,7 +393,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
}
}

err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
if err != nil {
return fmt.Errorf("can't initialize plugins: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP
return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined")
}

err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths)
if err != nil {
return nil, fmt.Errorf("unable to run plugin broker: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiserver/alerts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.Wat
}

func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse {
body := CreateTestMachine(t, router, "")
ValidateMachine(t, "test", config.API.Server.DbConfig)
body := CreateTestMachine(t, ctx, router, "")
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body))
Expand Down
8 changes: 2 additions & 6 deletions pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csc
return router, config
}

func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
ctx := context.TODO()

func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)

Expand Down Expand Up @@ -269,16 +267,14 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map
return response, resp.Code
}

func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string {
regReq := MachineTest
regReq.RegistrationToken = token
b, err := json.Marshal(regReq)
require.NoError(t, err)

body := string(b)

ctx := context.Background()

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))
req.Header.Set("User-Agent", UserAgent)
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiserver/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestLogin(t *testing.T) {
ctx := context.Background()
router, config := NewAPITest(t, ctx)

body := CreateTestMachine(t, router, "")
body := CreateTestMachine(t, ctx, router, "")

// Login with machine not validated yet
w := httptest.NewRecorder()
Expand Down Expand Up @@ -53,7 +53,7 @@ func TestLogin(t *testing.T) {
assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())

// Validate machine
ValidateMachine(t, "test", config.API.Server.DbConfig)
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)

// Login with invalid password
w = httptest.NewRecorder()
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/machines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
ctx := context.Background()
router, _ := NewAPITest(t, ctx)

body := CreateTestMachine(t, router, "")
body := CreateTestMachine(t, ctx, router, "")

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body))
Expand Down
8 changes: 4 additions & 4 deletions pkg/csplugin/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type ProfileAlert struct {
Alert *models.Alert
}

func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error {
pb.PluginChannel = make(chan ProfileAlert)
pb.notificationConfigsByPluginType = make(map[string][][]byte)
pb.notificationPluginByName = make(map[string]protobufs.NotifierServer)
Expand All @@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs
if err := pb.loadConfig(configPaths.NotificationDir); err != nil {
return fmt.Errorf("while loading plugin config: %w", err)
}
if err := pb.loadPlugins(configPaths.PluginDir); err != nil {
if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil {
return fmt.Errorf("while loading plugin: %w", err)
}
pb.watcher = PluginWatcher{}
Expand Down Expand Up @@ -230,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error {
return nil
}

func (pb *PluginBroker) loadPlugins(path string) error {
func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error {
binaryPaths, err := listFilesAtPath(path)
if err != nil {
return err
Expand Down Expand Up @@ -265,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error {
return err
}
data = []byte(csstring.StrictExpand(string(data), os.LookupEnv))
_, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data})
_, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data})
if err != nil {
return fmt.Errorf("while configuring %s: %w", pc.Name, err)
}
Expand Down
12 changes: 10 additions & 2 deletions pkg/csplugin/broker_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csplugin

import (
"context"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() {

func (s *PluginSuite) SetupSubTest() {
var err error

t := s.T()

s.runDir, err = os.MkdirTemp("", "cs_plugin_test")
Expand Down Expand Up @@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() {

func (s *PluginSuite) TearDownSubTest() {
t := s.T()

if s.pluginBroker != nil {
s.pluginBroker.Kill()
s.pluginBroker = nil
Expand All @@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() {
os.Remove("./out")
}

func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) {
pb := PluginBroker{}

if procCfg == nil {
procCfg = &csconfig.PluginCfg{}
}

profiles := csconfig.NewDefaultConfig().API.Server.Profiles
profiles = append(profiles, &csconfig.ProfileCfg{
Notifications: []string{"dummy_default"},
})
err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{

err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{
PluginDir: s.pluginDir,
NotificationDir: s.notifDir,
})

s.pluginBroker = &pb

return s.pluginBroker, err
}
24 changes: 17 additions & 7 deletions pkg/csplugin/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package csplugin

import (
"bytes"
"context"
"encoding/json"
"io"
"os"
Expand Down Expand Up @@ -53,6 +54,7 @@ func (s *PluginSuite) writeconfig(config PluginConfig) {
}

func (s *PluginSuite) TestBrokerInit() {
ctx := context.Background()
tests := []struct {
name string
action func(*testing.T)
Expand Down Expand Up @@ -135,20 +137,22 @@ func (s *PluginSuite) TestBrokerInit() {
tc.action(t)
}

_, err := s.InitBroker(&tc.procCfg)
_, err := s.InitBroker(ctx, &tc.procCfg)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}

func (s *PluginSuite) TestBrokerNoThreshold() {
ctx := context.Background()

var alerts []models.Alert

DefaultEmptyTicker = 50 * time.Millisecond

t := s.T()

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand Down Expand Up @@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
}

func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
ctx := context.Background()

// test grouping by "time"
DefaultEmptyTicker = 50 * time.Millisecond

Expand All @@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
cfg.GroupWait = 1 * time.Second
s.writeconfig(cfg)

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand All @@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
}

func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond

t := s.T()
Expand All @@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
cfg.GroupWait = 4 * time.Second
s.writeconfig(cfg)

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand Down Expand Up @@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
}

func (s *PluginSuite) TestBrokerRunGroupThreshold() {
ctx := context.Background()
// test grouping by "size"
DefaultEmptyTicker = 50 * time.Millisecond

Expand All @@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
cfg.GroupThreshold = 4
s.writeconfig(cfg)

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand Down Expand Up @@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
}

func (s *PluginSuite) TestBrokerRunTimeThreshold() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond

t := s.T()
Expand All @@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
cfg.GroupWait = 1 * time.Second
s.writeconfig(cfg)

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand All @@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
}

func (s *PluginSuite) TestBrokerRunSimple() {
ctx := context.Background()
DefaultEmptyTicker = 50 * time.Millisecond

t := s.T()

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand Down
7 changes: 5 additions & 2 deletions pkg/csplugin/broker_win_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package csplugin

import (
"bytes"
"context"
"encoding/json"
"io"
"os"
Expand All @@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions
*/

func (s *PluginSuite) TestBrokerInit() {
ctx := context.Background()
tests := []struct {
name string
action func(*testing.T)
Expand Down Expand Up @@ -59,16 +61,17 @@ func (s *PluginSuite) TestBrokerInit() {
if tc.action != nil {
tc.action(t)
}
_, err := s.InitBroker(&tc.procCfg)
_, err := s.InitBroker(ctx, &tc.procCfg)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}

func (s *PluginSuite) TestBrokerRun() {
ctx := context.Background()
t := s.T()

pb, err := s.InitBroker(nil)
pb, err := s.InitBroker(ctx, nil)
require.NoError(t, err)

tomb := tomb.Tomb{}
Expand Down
Loading

0 comments on commit 8ff58ee

Please sign in to comment.