From b9bccfa56f3393dccf19ca97b4a2673efc0feaff Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 9 Oct 2024 13:06:03 +0200 Subject: [PATCH] context propagation: pkg/apiserver (#3272) * context propagation: apic.Push() * context propagation: NewServer() * lint --- .golangci.yml | 2 +- cmd/crowdsec-cli/clipapi/papi.go | 2 +- cmd/crowdsec/api.go | 10 ++++++-- cmd/crowdsec/serve.go | 10 ++++---- pkg/apiserver/alerts_test.go | 6 ++--- pkg/apiserver/api_key_test.go | 3 +-- pkg/apiserver/apic.go | 12 +++++----- pkg/apiserver/apic_test.go | 2 +- pkg/apiserver/apiserver.go | 10 ++++---- pkg/apiserver/apiserver_test.go | 40 ++++++++++++++++---------------- pkg/apiserver/jwt_test.go | 3 +-- pkg/apiserver/machines_test.go | 20 ++++++---------- 12 files changed, 59 insertions(+), 61 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 786bb18d8e7..4909d3e60c0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -321,7 +321,7 @@ issues: # `err` is often shadowed, we may continue to do it - linters: - govet - text: "shadow: declaration of \"err\" shadows declaration" + text: "shadow: declaration of \"(err|ctx)\" shadows declaration" - linters: - errcheck diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index b8101a0fb34..461215c3a39 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -127,7 +127,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client return fmt.Errorf("unable to initialize API client: %w", err) } - t.Go(apic.Push) + t.Go(func() error { return apic.Push(ctx) }) papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) if err != nil { diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index c57b8d87cff..6ab41def16f 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "runtime" @@ -14,12 +15,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { +func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) { if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil { log.Info("push and pull to Central API disabled") } - apiServer, err := apiserver.NewServer(cConfig.API.Server) + apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server) if err != nil { return nil, fmt.Errorf("unable to run local API: %w", err) } @@ -58,11 +59,14 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { func serveAPIServer(apiServer *apiserver.APIServer) { apiReady := make(chan bool, 1) + apiTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveAPIServer") + go func() { defer trace.CatchPanic("crowdsec/runAPIServer") log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) + if err := apiServer.Run(apiReady); err != nil { log.Fatal(err) } @@ -76,6 +80,7 @@ func serveAPIServer(apiServer *apiserver.APIServer) { <-apiTomb.Dying() // lock until go routine is dying pluginTomb.Kill(nil) log.Infof("serve: shutting down api server") + return apiServer.Shutdown() }) <-apiReady @@ -87,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool { return true } } + return false } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index f1a658e9512..14602c425fe 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { func reloadHandler(sig os.Signal) (*csconfig.Config, error) { var tmpFile string + ctx := context.TODO() + // re-initialize tombs acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} @@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } @@ -88,7 +90,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { return nil, err } - if err := hub.Load(); err != nil { + if err = hub.Load(); err != nil { return nil, err } @@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return fmt.Errorf("api server init: %w", err) } @@ -390,7 +392,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error { return err } - if err := hub.Load(); err != nil { + if err = hub.Load(); err != nil { return err } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 0e89ddb2137..cd981f76542 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -65,7 +65,7 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur } func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { - router, config := NewAPITest(t) + router, config := NewAPITest(t, ctx) loginResp := LoginToTestAPI(t, ctx, router, config) return router, loginResp, config @@ -137,7 +137,7 @@ func TestCreateAlert(t *testing.T) { func TestCreateAlertChannels(t *testing.T) { ctx := context.Background() - apiServer, config := NewAPIServer(t) + apiServer, config := NewAPIServer(t, ctx) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() @@ -437,7 +437,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { // cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"} cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" - server, err := NewServer(cfg.API.Server) + server, err := NewServer(ctx, cfg.API.Server) require.NoError(t, err) err = server.InitController() diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 014f255b892..e6ed68a6e0d 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -11,9 +11,8 @@ import ( ) func TestAPIKey(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index e62bc663c16..a2fb0e85749 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -256,7 +256,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient } // keep track of all alerts in cache and push it to CAPI every PushInterval. -func (a *apic) Push() error { +func (a *apic) Push(ctx context.Context) error { defer trace.CatchPanic("lapi/pushToAPIC") var cache models.AddSignalsRequest @@ -276,7 +276,7 @@ func (a *apic) Push() error { return nil } - go a.Send(&cache) + go a.Send(ctx, &cache) return nil case <-ticker.C: @@ -289,7 +289,7 @@ func (a *apic) Push() error { a.mu.Unlock() log.Infof("Signal push: %d signals to push", len(cacheCopy)) - go a.Send(&cacheCopy) + go a.Send(ctx, &cacheCopy) } case alerts := <-a.AlertsAddChan: var signals []*models.AddSignalsRequestItem @@ -351,7 +351,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig return true } -func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { +func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) { /*we do have a problem with this : The apic.Push background routine reads from alertToPush chan. This chan is filled by Controller.CreateAlert @@ -375,7 +375,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { for { if pageEnd >= len(cache) { send = cache[pageStart:] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() @@ -389,7 +389,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } send = cache[pageStart:pageEnd] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 51b1f43c707..b52dc9e44cc 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -1134,7 +1134,7 @@ func TestAPICPush(t *testing.T) { api.Shutdown() }() - err = api.Push() + err = api.Push(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount()) }) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 8fe500c7f52..bdf2d4148cc 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -159,11 +159,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro // NewServer creates a LAPI server. // It sets up a gin router, a database client, and a controller. -func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { +func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler - ctx := context.TODO() - dbClient, err := database.NewClient(ctx, config.DbConfig) if err != nil { return nil, fmt.Errorf("unable to init database client: %w", err) @@ -300,8 +298,8 @@ func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } -func (s *APIServer) apicPush() error { - if err := s.apic.Push(); err != nil { +func (s *APIServer) apicPush(ctx context.Context) error { + if err := s.apic.Push(ctx); err != nil { log.Errorf("capi push: %s", err) return err } @@ -337,7 +335,7 @@ func (s *APIServer) papiSync() error { } func (s *APIServer) initAPIC(ctx context.Context) { - s.apic.pushTomb.Go(s.apicPush) + s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) }) s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index c3f69c5c365..b04ad687e4e 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -3,7 +3,6 @@ package apiserver import ( "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "os" @@ -41,7 +40,7 @@ var ( MachineID: &testMachineID, Password: &testPassword, } - UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version) + UserAgent = "crowdsec-test/" + version.Version emptyBody = strings.NewReader("") ) @@ -135,12 +134,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { +func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) log.Printf("Creating new API server") @@ -149,8 +148,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { - apiServer, config := NewAPIServer(t) +func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t, ctx) err := apiServer.InitController() require.NoError(t, err) @@ -161,12 +160,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { +func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) err = apiServer.InitController() @@ -302,28 +301,29 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab } func TestWithWrongDBConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) config.API.Server.DbConfig.Type = "test" - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") assert.Nil(t, apiServer) } func TestWithWrongFlushConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) maxItems := -1 config.API.Server.DbConfig.Flush.MaxItems = &maxItems - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "max_items can't be zero or negative") assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) @@ -349,6 +349,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0 */ func TestLoggingDebugToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -370,7 +372,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") expectedLines := []string{"/test42"} cfg.LogLevel = ptr.Of(log.DebugLevel) @@ -378,12 +380,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) - ctx := context.Background() - w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) @@ -402,6 +402,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } func TestLoggingErrorToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -423,19 +425,17 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) - ctx := context.Background() - w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 293cc38bd2c..710cf82ad00 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -11,9 +11,8 @@ import ( ) func TestLogin(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) body := CreateTestMachine(t, router, "") diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 44c370732c7..e60cec30e54 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -15,9 +15,8 @@ import ( ) func TestCreateMachine(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Create machine with invalid format w := httptest.NewRecorder() @@ -53,10 +52,9 @@ func TestCreateMachine(t *testing.T) { } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) - router.TrustedPlatform = "X-Real-IP" - ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) + router.TrustedPlatform = "X-Real-IP" // Create machine b, err := json.Marshal(MachineTest) @@ -79,9 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -106,9 +103,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) - ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -132,9 +128,8 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) body := CreateTestMachine(t, router, "") @@ -153,9 +148,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) { } func TestAutoRegistration(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Invalid registration token / valid source IP regReq := MachineTest