From fee3debdccc71b7b4848cea95e8da0ea276117df Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:00:58 +0200 Subject: [PATCH] context propagation: pkg/database/machines (#3248) --- cmd/crowdsec-cli/climachine/machines.go | 56 ++++++++++++---------- cmd/crowdsec-cli/clisupport/support.go | 6 +-- pkg/apiserver/apic.go | 4 +- pkg/apiserver/apic_metrics.go | 10 ++-- pkg/apiserver/apiserver_test.go | 10 ++-- pkg/apiserver/controllers/v1/heartbeat.go | 4 +- pkg/apiserver/controllers/v1/machines.go | 4 +- pkg/apiserver/middlewares/v1/jwt.go | 15 +++--- pkg/apiserver/usage_metrics_test.go | 32 ++++++------- pkg/database/alerts.go | 4 +- pkg/database/machines.go | 58 +++++++++++------------ 11 files changed, 109 insertions(+), 94 deletions(-) diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go index 30948f43056..1fbedcf57fd 100644 --- a/cmd/crowdsec-cli/climachine/machines.go +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -1,6 +1,7 @@ package climachine import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -210,11 +211,11 @@ func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { return nil } -func (cli *cliMachines) List(out io.Writer, db *database.Client) error { +func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error { // XXX: must use the provided db object, the one in the struct might be nil // (calling List directly skips the PersistentPreRunE) - machines, err := db.ListMachines() + machines, err := db.ListMachines(ctx) if err != nil { return fmt.Errorf("unable to list machines: %w", err) } @@ -251,8 +252,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command { Example: `cscli machines list`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.List(color.Output, cli.db) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) }, } @@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command { cscli machines add MyTestMachine --auto cscli machines add MyTestMachine --password MyPassword cscli machines add -f- --auto > /tmp/mycreds.yaml`, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) }, } @@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`, return cmd } -func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { var ( err error machineID string @@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri password := strfmt.Password(machinePassword) - _, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType) + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) if err != nil { return fmt.Errorf("unable to create machine: %w", err) } @@ -399,6 +400,7 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp var err error cfg := cli.cfg() + ctx := cmd.Context() // need to load config and db because PersistentPreRunE is not called for completions @@ -407,13 +409,13 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp return nil, cobra.ShellCompDirectiveNoFileComp } - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + cli.db, err = require.DBClient(ctx, cfg.DbConfig) if err != nil { cobra.CompError("unable to list machines " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp } - machines, err := cli.db.ListMachines() + machines, err := cli.db.ListMachines(ctx) if err != nil { cobra.CompError("unable to list machines " + err.Error()) return nil, cobra.ShellCompDirectiveNoFileComp @@ -430,9 +432,9 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error { +func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error { for _, machineID := range machines { - if err := cli.db.DeleteWatcher(machineID); err != nil { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { var notFoundErr *database.MachineNotFoundError if ignoreMissing && errors.As(err, ¬FoundErr) { return nil @@ -460,8 +462,8 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) }, } @@ -471,7 +473,7 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command { return cmd } -func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error { +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { if duration < 2*time.Minute && !notValidOnly { if yes, err := ask.YesNo( "The duration you provided is less than 2 minutes. "+ @@ -484,12 +486,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } machines := []*ent.Machine{} - if pending, err := cli.db.QueryPendingMachine(); err == nil { + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { machines = append(machines, pending...) } if !notValidOnly { - if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { machines = append(machines, pending...) } } @@ -512,7 +514,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } } - deleted, err := cli.db.BulkDeleteWatchers(machines) + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) if err != nil { return fmt.Errorf("unable to prune machines: %w", err) } @@ -540,8 +542,8 @@ cscli machines prune --duration 1h cscli machines prune --not-validated-only --force`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, notValidOnly, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) }, } @@ -553,8 +555,8 @@ cscli machines prune --not-validated-only --force`, return cmd } -func (cli *cliMachines) validate(machineID string) error { - if err := cli.db.ValidateMachine(machineID); err != nil { +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) } @@ -571,8 +573,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command { Example: `cscli machines validate "machine_name"`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.validate(args[0]) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) }, } @@ -690,9 +692,11 @@ func (cli *cliMachines) newInspectCmd() *cobra.Command { Args: cobra.ExactArgs(1), DisableAutoGenTag: true, ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() machineID := args[0] - machine, err := cli.db.QueryMachineByID(machineID) + + machine, err := cli.db.QueryMachineByID(ctx, machineID) if err != nil { return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) } diff --git a/cmd/crowdsec-cli/clisupport/support.go b/cmd/crowdsec-cli/clisupport/support.go index 7e41518805a..4474f5c8f11 100644 --- a/cmd/crowdsec-cli/clisupport/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -210,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *dat return nil } -func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting agents") if db == nil { @@ -220,7 +220,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) cm := climachine.New(cli.cfg) - if err := cm.List(out, db); err != nil { + if err := cm.List(ctx, out, db); err != nil { return err } @@ -529,7 +529,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect bouncers information: %s", err) } - if err = cli.dumpAgents(zipWriter, db); err != nil { + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { log.Warnf("could not collect agents information: %s", err) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index b5384c6cc5c..c79d5f88e3f 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -85,7 +85,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { func (a *apic) FetchScenariosListFromDB() ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + ctx := context.TODO() + + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index e5821e4c1e2..16b2328dbe9 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -27,7 +27,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) - lps, err := a.dbClient.ListMachines() + lps, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, nil, err } @@ -186,7 +186,7 @@ func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { } func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -230,8 +230,8 @@ func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -277,7 +277,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 89c75f35d21..0db1ee5dcdc 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -182,12 +182,12 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { } func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { - ctx := context.Background() + ctx := context.TODO() dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - err = dbClient.ValidateMachine(machineID) + err = dbClient.ValidateMachine(ctx, machineID) require.NoError(t, err) } @@ -197,7 +197,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - machines, err := dbClient.ListMachines() + machines, err := dbClient.ListMachines(ctx) require.NoError(t, err) for _, machine := range machines { @@ -332,7 +332,7 @@ func TestUnknownPath(t *testing.T) { req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) } /* @@ -390,7 +390,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) // wait for the request to happen time.Sleep(500 * time.Millisecond) diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index e1231eaa9ec..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -9,7 +9,9 @@ import ( func (c *Controller) HeartBeat(gctx *gin.Context) { machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + ctx := gctx.Request.Context() + + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 0030f7d3b39..ff59e389cb1 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, } func (c *Controller) CreateMachine(gctx *gin.Context) { + ctx := gctx.Request.Context() + var input models.WatcherRegistrationRequest if err := gctx.ShouldBindJSON(&input); err != nil { @@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { return } - if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 64406deff3e..17ca5b28359 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -55,6 +55,7 @@ type authInput struct { } func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + ctx := c.Request.Context() ret := authInput{} if j.TlsAuth == nil { @@ -76,7 +77,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if ent.IsNotFound(err) { // Machine was not found, let's create it logger.Infof("machine %s not found, create it", ret.machineID) @@ -91,7 +92,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { password := strfmt.Password(pwd) - ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } @@ -175,6 +176,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { auth *authInput ) + ctx := c.Request.Context() + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) if err != nil { @@ -198,7 +201,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -208,7 +211,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { clientIP := c.ClientIP() if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -218,7 +221,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -231,7 +234,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", clientIP) diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index 019de5fb970..b231fb22ad8 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -30,7 +30,7 @@ func TestLPMetrics(t *testing.T) { name: "empty metrics for LP", body: `{ }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing log processor data", authType: PASSWORD, }, @@ -50,7 +50,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -74,7 +74,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -98,7 +98,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing remediation component data", authType: APIKEY, }, @@ -117,7 +117,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedResponse: "", expectedMetricsCount: 1, expectedFeatureFlags: "a,b,c", @@ -138,7 +138,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "log_processors.0.datasources in body is required", authType: PASSWORD, }, @@ -157,7 +157,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedOSName: "foo", expectedOSVersion: "42", @@ -179,7 +179,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "log_processors.0.os.name in body is required", authType: PASSWORD, }, @@ -199,7 +199,7 @@ func TestLPMetrics(t *testing.T) { assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - machine, _ := dbClient.QueryMachineByID("test") + machine, _ := dbClient.QueryMachineByID(ctx, "test") metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) @@ -233,7 +233,7 @@ func TestRCMetrics(t *testing.T) { name: "empty metrics for RC", body: `{ }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing remediation component data", authType: APIKEY, }, @@ -251,7 +251,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -273,7 +273,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -295,7 +295,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing log processor data", authType: PASSWORD, }, @@ -312,7 +312,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedResponse: "", expectedMetricsCount: 1, expectedFeatureFlags: "a,b,c", @@ -331,7 +331,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedOSName: "foo", expectedOSVersion: "42", @@ -351,7 +351,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "remediation_components.0.os.name in body is required", authType: APIKEY, }, diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 3dfb0dc8197..d2760a209f9 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -687,8 +687,10 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str err error ) + ctx := context.TODO() + if machineID != "" { - owner, err = c.QueryMachineByID(machineID) + owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 27d737e625e..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -72,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, return nil } -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine: %s", err) @@ -82,20 +82,20 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } @@ -113,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) @@ -122,11 +122,11 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) @@ -135,8 +135,8 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } @@ -144,8 +144,8 @@ func (c *Client) ListMachines() ([]*ent.Machine, error) { return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } @@ -157,8 +157,8 @@ func (c *Client) ValidateMachine(machineID string) error { return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) @@ -167,11 +167,11 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -183,13 +183,13 @@ func (c *Client) DeleteWatcher(name string) error { return nil } -func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { ids := make([]int, len(machines)) for i, b := range machines { ids[i] = b.ID } - nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, err } @@ -197,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } @@ -206,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine in database: %w", err) } @@ -218,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { return nil } -func (c *Client) UpdateMachineIP(ipAddr string, id int) error { +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine IP in database: %w", err) } @@ -229,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine version in database: %w", err) } @@ -240,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } @@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { return false, nil } -func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) { +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { return c.Ent.Machine.Query().Where( machine.Or( machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), ), - ).All(c.CTX) + ).All(ctx) }