From c861b3de66aec35194b2a0882b790f85d0228991 Mon Sep 17 00:00:00 2001 From: marco Date: Thu, 3 Oct 2024 11:03:35 +0200 Subject: [PATCH] context propagation: NewServer() --- cmd/crowdsec/api.go | 5 +++-- cmd/crowdsec/serve.go | 6 ++++-- pkg/apiserver/alerts_test.go | 6 +++--- pkg/apiserver/api_key_test.go | 3 +-- pkg/apiserver/apiserver.go | 4 +--- pkg/apiserver/apiserver_test.go | 33 +++++++++++++++++---------------- pkg/apiserver/jwt_test.go | 3 +-- pkg/apiserver/machines_test.go | 20 +++++++------------- 8 files changed, 37 insertions(+), 43 deletions(-) diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index c57b8d87cff..12eff952f77 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "runtime" @@ -14,12 +15,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { +func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) { if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil { log.Info("push and pull to Central API disabled") } - apiServer, err := apiserver.NewServer(cConfig.API.Server) + apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server) if err != nil { return nil, fmt.Errorf("unable to run local API: %w", err) } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index f1a658e9512..5c8a830668a 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { func reloadHandler(sig os.Signal) (*csconfig.Config, error) { var tmpFile string + ctx := context.TODO() + // re-initialize tombs acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} @@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } @@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return fmt.Errorf("api server init: %w", err) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 0e89ddb2137..cd981f76542 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -65,7 +65,7 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur } func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { - router, config := NewAPITest(t) + router, config := NewAPITest(t, ctx) loginResp := LoginToTestAPI(t, ctx, router, config) return router, loginResp, config @@ -137,7 +137,7 @@ func TestCreateAlert(t *testing.T) { func TestCreateAlertChannels(t *testing.T) { ctx := context.Background() - apiServer, config := NewAPIServer(t) + apiServer, config := NewAPIServer(t, ctx) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() @@ -437,7 +437,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { // cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"} cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" - server, err := NewServer(cfg.API.Server) + server, err := NewServer(ctx, cfg.API.Server) require.NoError(t, err) err = server.InitController() diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 014f255b892..e6ed68a6e0d 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -11,9 +11,8 @@ import ( ) func TestAPIKey(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index a5db73d65d4..bdf2d4148cc 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -159,11 +159,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro // NewServer creates a LAPI server. // It sets up a gin router, a database client, and a controller. -func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { +func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler - ctx := context.TODO() - dbClient, err := database.NewClient(ctx, config.DbConfig) if err != nil { return nil, fmt.Errorf("unable to init database client: %w", err) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index c3f69c5c365..0b3690a6c43 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -135,12 +135,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { +func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) log.Printf("Creating new API server") @@ -149,8 +149,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { - apiServer, config := NewAPIServer(t) +func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t, ctx) err := apiServer.InitController() require.NoError(t, err) @@ -161,12 +161,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { +func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) err = apiServer.InitController() @@ -302,28 +302,29 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab } func TestWithWrongDBConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) config.API.Server.DbConfig.Type = "test" - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") assert.Nil(t, apiServer) } func TestWithWrongFlushConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) maxItems := -1 config.API.Server.DbConfig.Flush.MaxItems = &maxItems - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "max_items can't be zero or negative") assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) @@ -349,6 +350,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0 */ func TestLoggingDebugToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -378,12 +381,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) - ctx := context.Background() - w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) @@ -402,6 +403,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } func TestLoggingErrorToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -430,12 +433,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) { err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) - ctx := context.Background() - w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 293cc38bd2c..710cf82ad00 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -11,9 +11,8 @@ import ( ) func TestLogin(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) body := CreateTestMachine(t, router, "") diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 44c370732c7..e60cec30e54 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -15,9 +15,8 @@ import ( ) func TestCreateMachine(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Create machine with invalid format w := httptest.NewRecorder() @@ -53,10 +52,9 @@ func TestCreateMachine(t *testing.T) { } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) - router.TrustedPlatform = "X-Real-IP" - ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) + router.TrustedPlatform = "X-Real-IP" // Create machine b, err := json.Marshal(MachineTest) @@ -79,9 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config := NewAPITest(t) - ctx := context.Background() + router, config := NewAPITest(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -106,9 +103,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) - ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -132,9 +128,8 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) body := CreateTestMachine(t, router, "") @@ -153,9 +148,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) { } func TestAutoRegistration(t *testing.T) { - router, _ := NewAPITest(t) - ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Invalid registration token / valid source IP regReq := MachineTest