diff --git a/ee/server/service/maintained_apps.go b/ee/server/service/maintained_apps.go index 316a88a4571a..2c83ed39b3e6 100644 --- a/ee/server/service/maintained_apps.go +++ b/ee/server/service/maintained_apps.go @@ -123,7 +123,7 @@ func (svc *Service) AddFleetMaintainedApp( } // TODO: labels validations, for now just use empty struct - payload.ValidatedLabels = &fleet.LabelIndentsWithScope{} + payload.ValidatedLabels = &fleet.LabelIdentsWithScope{} // Create record in software installers table _, titleID, err = svc.ds.MatchOrCreateSoftwareInstaller(ctx, payload) diff --git a/ee/server/service/software_installers.go b/ee/server/service/software_installers.go index c8de13460675..1deefc6fcb74 100644 --- a/ee/server/service/software_installers.go +++ b/ee/server/service/software_installers.go @@ -103,7 +103,7 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet. return nil } -func (svc *Service) validateSoftwareLabels(ctx context.Context, labelsIncludeAny, labelsExcludeAny []string) (*fleet.LabelIndentsWithScope, error) { +func (svc *Service) validateSoftwareLabels(ctx context.Context, labelsIncludeAny, labelsExcludeAny []string) (*fleet.LabelIdentsWithScope, error) { var names []string var scope fleet.LabelScope switch { @@ -119,7 +119,7 @@ func (svc *Service) validateSoftwareLabels(ctx context.Context, labelsIncludeAny if len(names) == 0 { // nothing to validate, return empty result - return &fleet.LabelIndentsWithScope{}, nil + return &fleet.LabelIdentsWithScope{}, nil } byName, err := svc.BatchValidateLabels(ctx, names) @@ -127,7 +127,7 @@ func (svc *Service) validateSoftwareLabels(ctx context.Context, labelsIncludeAny return nil, err } - return &fleet.LabelIndentsWithScope{ + return &fleet.LabelIdentsWithScope{ LabelScope: scope, ByName: byName, }, nil diff --git a/server/datastore/mysql/software_installers.go b/server/datastore/mysql/software_installers.go index b1aba354068d..b604f7f84f9e 100644 --- a/server/datastore/mysql/software_installers.go +++ b/server/datastore/mysql/software_installers.go @@ -119,7 +119,8 @@ func (ds *Datastore) MatchOrCreateSoftwareInstaller(ctx context.Context, payload } } - stmt := ` + if err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { + stmt := ` INSERT INTO software_installers ( team_id, global_or_team_id, @@ -141,47 +142,49 @@ INSERT INTO software_installers ( fleet_library_app_id ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT name FROM users WHERE id = ?), (SELECT email FROM users WHERE id = ?), ?)` - args := []interface{}{ - tid, - globalOrTeamID, - titleID, - payload.StorageID, - payload.Filename, - payload.Extension, - payload.Version, - strings.Join(payload.PackageIDs, ","), - installScriptID, - payload.PreInstallQuery, - postInstallScriptID, - uninstallScriptID, - payload.Platform, - payload.SelfService, - payload.UserID, - payload.UserID, - payload.UserID, - payload.FleetLibraryAppID, - } + args := []interface{}{ + tid, + globalOrTeamID, + titleID, + payload.StorageID, + payload.Filename, + payload.Extension, + payload.Version, + strings.Join(payload.PackageIDs, ","), + installScriptID, + payload.PreInstallQuery, + postInstallScriptID, + uninstallScriptID, + payload.Platform, + payload.SelfService, + payload.UserID, + payload.UserID, + payload.UserID, + payload.FleetLibraryAppID, + } - res, err := ds.writer(ctx).ExecContext(ctx, stmt, args...) - if err != nil { - if IsDuplicate(err) { - // already exists for this team/no team - err = alreadyExists("SoftwareInstaller", payload.Title) + res, err := tx.ExecContext(ctx, stmt, args...) + if err != nil { + if IsDuplicate(err) { + // already exists for this team/no team + err = alreadyExists("SoftwareInstaller", payload.Title) + } + return err } - return 0, 0, ctxerr.Wrap(ctx, err, "insert software installer") - } - id, _ := res.LastInsertId() + id, _ := res.LastInsertId() + installerID = uint(id) //nolint:gosec // dismiss G115 - // TODO: how does should this check work in the context of editng an existing software installer to - // remove existing labels (i.e. switching from custom targets to all hosts)? - if payload.ValidatedLabels.LabelScope != "" { - if err := ds.upsertSoftwareInstallerLabels(ctx, uint(id), *payload.ValidatedLabels); err != nil { //nolint:gosec // dismiss G115 - return uint(id), titleID, ctxerr.Wrap(ctx, err, "upsert software installer labels") //nolint:gosec // dismiss G115 + if err := setOrUpdateSoftwareInstallerLabelsDB(ctx, tx, installerID, *payload.ValidatedLabels); err != nil { + return ctxerr.Wrap(ctx, err, "upsert software installer labels") } + + return nil + }); err != nil { + return 0, 0, ctxerr.Wrap(ctx, err, "insert software installer") } - return uint(id), titleID, nil //nolint:gosec // dismiss G115 + return installerID, titleID, nil } func (ds *Datastore) getOrGenerateSoftwareInstallerTitleID(ctx context.Context, payload *fleet.UploadSoftwareInstallerPayload) (uint, error) { @@ -234,29 +237,18 @@ func (ds *Datastore) addSoftwareTitleToMatchingSoftware(ctx context.Context, tit return ctxerr.Wrap(ctx, err, "adding fk reference in software to software_titles") } -func (ds *Datastore) upsertSoftwareInstallerLabels(ctx context.Context, installerID uint, labels fleet.LabelIndentsWithScope) error { - var exclude bool - switch labels.LabelScope { - case fleet.LabelScopeIncludeAny: - exclude = false - case fleet.LabelScopeExcludeAny: - exclude = true - default: - return errors.New("invalid label scope") - } - +// setOrUpdateSoftwareInstallerLabelsDB sets or updates the label associations for the specified software +// installer. If no labels are provided, it will remove all label associations with the software installer. +func setOrUpdateSoftwareInstallerLabelsDB(ctx context.Context, tx sqlx.ExtContext, installerID uint, labels fleet.LabelIdentsWithScope) error { labelIds := make([]uint, 0, len(labels.ByName)) for _, label := range labels.ByName { labelIds = append(labelIds, label.LabelID) } - level.Debug(ds.logger).Log("msg", "upsert software installer labels", "installer_id", installerID, "label_ids", fmt.Sprintf("%v", labelIds), "exclude", exclude) - // remove existing labels delArgs := []interface{}{installerID} delStmt := `DELETE FROM software_installer_labels WHERE software_installer_id = ?` if len(labelIds) > 0 { - // TODO: we might consider skipping this step which preserves existing labels and just deleting everything each time inStmt, args, err := sqlx.In(` AND label_id NOT IN (?)`, labelIds) if err != nil { return ctxerr.Wrap(ctx, err, "build delete existing software installer labels query") @@ -264,24 +256,37 @@ func (ds *Datastore) upsertSoftwareInstallerLabels(ctx context.Context, installe delArgs = append(delArgs, args...) delStmt += inStmt } - _, err := ds.writer(ctx).ExecContext(ctx, delStmt, delArgs...) + _, err := tx.ExecContext(ctx, delStmt, delArgs...) if err != nil { return ctxerr.Wrap(ctx, err, "delete existing software installer labels") } // insert new labels - stmt := `INSERT INTO software_installer_labels (software_installer_id, label_id, exclude) VALUES %s ON DUPLICATE KEY UPDATE software_installer_id = software_installer_id, label_id = label_id, exclude = VALUES(exclude)` - var placeholders string - var insertArgs []interface{} - for _, lid := range labelIds { - placeholders += "(?, ?, ?)," - insertArgs = append(insertArgs, installerID, lid, exclude) - } - placeholders = strings.TrimSuffix(placeholders, ",") + if len(labelIds) > 0 { + var exclude bool + switch labels.LabelScope { + case fleet.LabelScopeIncludeAny: + exclude = false + case fleet.LabelScopeExcludeAny: + exclude = true + default: + // this should never happen + return ctxerr.New(ctx, "invalid label scope") + } - _, err = ds.writer(ctx).ExecContext(ctx, fmt.Sprintf(stmt, placeholders), insertArgs...) - if err != nil { - return ctxerr.Wrap(ctx, err, "insert software installer label") + stmt := `INSERT INTO software_installer_labels (software_installer_id, label_id, exclude) VALUES %s ON DUPLICATE KEY UPDATE software_installer_id = software_installer_id, label_id = label_id, exclude = VALUES(exclude)` + var placeholders string + var insertArgs []interface{} + for _, lid := range labelIds { + placeholders += "(?, ?, ?)," + insertArgs = append(insertArgs, installerID, lid, exclude) + } + placeholders = strings.TrimSuffix(placeholders, ",") + + _, err = tx.ExecContext(ctx, fmt.Sprintf(stmt, placeholders), insertArgs...) + if err != nil { + return ctxerr.Wrap(ctx, err, "insert software installer label") + } } return nil diff --git a/server/fleet/labels.go b/server/fleet/labels.go index f56f27b83776..01538b6152cd 100644 --- a/server/fleet/labels.go +++ b/server/fleet/labels.go @@ -217,7 +217,7 @@ const ( LabelScopeIncludeAll LabelScope = "include_all" ) -type LabelIndentsWithScope struct { +type LabelIdentsWithScope struct { LabelScope LabelScope ByName map[string]LabelIdent } diff --git a/server/fleet/software_installer.go b/server/fleet/software_installer.go index 08459032f7ea..5c8a7a089ae6 100644 --- a/server/fleet/software_installer.go +++ b/server/fleet/software_installer.go @@ -340,7 +340,7 @@ type UploadSoftwareInstallerPayload struct { LabelsExcludeAny []string // names of "exclude any" labels // ValidatedLabels is a struct that contains the validated labels for the software installer. It // is nil if the labels have not been validated. - ValidatedLabels *LabelIndentsWithScope + ValidatedLabels *LabelIdentsWithScope } type UpdateSoftwareInstallerPayload struct { diff --git a/server/service/labels.go b/server/service/labels.go index 5e72724792cc..2b0ce6f59836 100644 --- a/server/service/labels.go +++ b/server/service/labels.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/fleetdm/fleet/v4/server" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxdb" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" @@ -662,18 +663,13 @@ func (svc *Service) BatchValidateLabels(ctx context.Context, labelNames []string return nil, nil } - labels, err := svc.ds.LabelIDsByName(ctx, labelNames) + uniqueNames := server.RemoveDuplicatesFromSlice(labelNames) + + labels, err := svc.ds.LabelIDsByName(ctx, uniqueNames) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name") } - uniqueNames := make(map[string]bool) - for _, entry := range labelNames { - if _, value := uniqueNames[entry]; !value { - uniqueNames[entry] = true - } - } - if len(labels) != len(uniqueNames) { return nil, &fleet.BadRequestError{ Message: "some or all the labels provided don't exist",