Skip to content

Commit

Permalink
Implement the batch-set of installers with labels
Browse files Browse the repository at this point in the history
  • Loading branch information
mna committed Dec 16, 2024
1 parent a9c45b1 commit cb63c06
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 8 deletions.
130 changes: 130 additions & 0 deletions server/datastore/mysql/software_installers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,55 @@ ON DUPLICATE KEY UPDATE
install_during_setup = COALESCE(?, install_during_setup)
`

const loadSoftwareInstallerID = `
SELECT
id
FROM
software_installers
WHERE
global_or_team_id = ? AND
-- this is guaranteed to select a single title_id, due to unique index
title_id IN (SELECT id FROM software_titles WHERE name = ? AND source = ? AND browser = '')
`

const deleteInstallerLabelsNotInList = `
DELETE
software_installer_labels
WHERE
software_installer_id = ? AND
label_id NOT IN (?)
`

const deleteAllInstallerLabels = `
DELETE
software_installer_labels
WHERE
software_installer_id = ?
`

const upsertInstallerLabels = `
INSERT INTO
software_installer_labels (
software_installer_id,
label_id,
exclude
)
VALUES
%s
ON DUPLICATE KEY UPDATE
exclude = VALUES(exclude)
`

const loadExistingInstallerLabels = `
SELECT
label_id,
exclude
FROM
software_installer_labels
WHERE
software_installer_id = ?
`

// use a team id of 0 if no-team
var globalOrTeamID uint
if tmID != nil {
Expand Down Expand Up @@ -1118,6 +1167,10 @@ ON DUPLICATE KEY UPDATE
}

for _, installer := range installers {
if installer.ValidatedLabels == nil {
return ctxerr.Wrapf(ctx, err, "labels have not been validated for installer with name %s", installer.Filename)
}

isRes, err := insertScriptContents(ctx, tx, installer.InstallScript)
if err != nil {
return ctxerr.Wrapf(ctx, err, "inserting install script contents for software installer with name %q", installer.Filename)
Expand Down Expand Up @@ -1203,6 +1256,83 @@ ON DUPLICATE KEY UPDATE
return ctxerr.Wrapf(ctx, err, "insert new/edited installer with name %q", installer.Filename)
}

// now that the software installer is created/updated, load its installer
// ID (cannot use res.LastInsertID due to the upsert statement, won't
// give the id in case of update)
var installerID uint
if err := sqlx.GetContext(ctx, tx, &installerID, loadSoftwareInstallerID, globalOrTeamID, installer.Title, installer.Source); err != nil {
return ctxerr.Wrapf(ctx, err, "load id of new/edited installer with name %q", installer.Filename)
}

// process the labels associated with that software installer
if len(installer.ValidatedLabels.ByName) == 0 {
// no label to apply, so just delete all existing labels if any
res, err := tx.ExecContext(ctx, deleteAllInstallerLabels, installerID)
if err != nil {
return ctxerr.Wrapf(ctx, err, "delete installer labels for %s", installer.Filename)
}

if n, _ := res.RowsAffected(); n > 0 && len(existing) > 0 {
// if it did delete a row, then the target changed so pending
// installs/uninstalls must be deleted
existing[0].IsMetadataModified = true
}
} else {
// there are new labels to apply, delete only the obsolete ones
labelIDs := make([]uint, 0, len(installer.ValidatedLabels.ByName))
for _, lbl := range installer.ValidatedLabels.ByName {
labelIDs = append(labelIDs, lbl.LabelID)
}
stmt, args, err := sqlx.In(deleteInstallerLabelsNotInList, installerID, labelIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to delete installer labels not in list")
}

res, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return ctxerr.Wrapf(ctx, err, "delete installer labels not in list for %s", installer.Filename)
}
if n, _ := res.RowsAffected(); n > 0 && len(existing) > 0 {
// if it did delete a row, then the target changed so pending
// installs/uninstalls must be deleted
existing[0].IsMetadataModified = true
}

excludeLabels := installer.ValidatedLabels.LabelScope == fleet.LabelScopeExcludeAny
if len(existing) > 0 && !existing[0].IsMetadataModified {
// load the remaining labels for that installer, so that we can detect
// if any label changed (if the counts differ, then labels did change,
// otherwise if the exclude bool changed, the target did change).
var existingLabels []struct {
LabelID uint `db:"label_id"`
Exclude bool `db:"exclude"`
}
if err := sqlx.SelectContext(ctx, tx, &existingLabels, loadExistingInstallerLabels, installerID); err != nil {
return ctxerr.Wrapf(ctx, err, "load existing labels for installer with name %q", installer.Filename)
}

if len(existingLabels) != len(labelIDs) {
existing[0].IsMetadataModified = true
}
if len(existingLabels) > 0 && existingLabels[0].Exclude != excludeLabels {
// same labels are provided, but the include <-> exclude changed
existing[0].IsMetadataModified = true
}
}

// upsert the new labels now that obsolete ones have been deleted
var upsertLabelArgs []any
for _, lblID := range labelIDs {
upsertLabelArgs = append(upsertLabelArgs, installerID, lblID, excludeLabels)
}
upsertLabelValues := strings.TrimSuffix(strings.Repeat("(?,?,?),", len(installer.ValidatedLabels.ByName)), ",")

_, err = tx.ExecContext(ctx, fmt.Sprintf(upsertInstallerLabels, upsertLabelValues), upsertLabelArgs...)
if err != nil {
return ctxerr.Wrapf(ctx, err, "insert new/edited labels for installer with name %q", installer.Filename)
}
}

// perform side effects if this was an update
if len(existing) > 0 {
if err := ds.runInstallerUpdateSideEffectsInTransaction(
Expand Down
12 changes: 4 additions & 8 deletions server/service/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"

"github.com/fleetdm/fleet/v4/server"
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
Expand Down Expand Up @@ -662,18 +663,13 @@ func (svc *Service) BatchValidateLabels(ctx context.Context, labelNames []string
return nil, nil
}

labels, err := svc.ds.LabelIDsByName(ctx, labelNames)
uniqueNames := server.RemoveDuplicatesFromSlice(labelNames)

labels, err := svc.ds.LabelIDsByName(ctx, uniqueNames)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name")
}

uniqueNames := make(map[string]bool)
for _, entry := range labelNames {
if _, value := uniqueNames[entry]; !value {
uniqueNames[entry] = true
}
}

if len(labels) != len(uniqueNames) {
return nil, &fleet.BadRequestError{
Message: "some or all the labels provided don't exist",
Expand Down

0 comments on commit cb63c06

Please sign in to comment.