Skip to content

Commit

Permalink
Add version pinning to plugin catalog (#24960)
Browse files Browse the repository at this point in the history
Adds the ability to pin a version for a specific plugin type + name to enable an easier plugin upgrade UX. After pinning and reloading, that version should be the only version in use.

No HTTP API implementation yet for managing pins, so no user-facing effects yet.
  • Loading branch information
tomhjp authored Jan 26, 2024
1 parent 55d5880 commit af27ab3
Show file tree
Hide file tree
Showing 17 changed files with 693 additions and 186 deletions.
12 changes: 11 additions & 1 deletion builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
return nil, err
}

dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
// Override the configured version if there is a pinned version.
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return nil, err
}
pluginVersion := config.PluginVersion
if pinnedVersion != "" {
pluginVersion = pinnedVersion
}

dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
Expand Down
143 changes: 90 additions & 53 deletions builtin/logical/database/path_config_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,58 +436,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyPluginName), nil
}

if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok {
config.PluginVersion = pluginVersionRaw.(string)
}

var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}

// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()

if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}

config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case req.Operation == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return nil, err
}

var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}

if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})

config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation)
if respErr != nil || err != nil {
return respErr, err
}

if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
Expand Down Expand Up @@ -536,7 +487,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}

// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil
}
Expand Down Expand Up @@ -613,6 +564,92 @@ func storeConfig(ctx context.Context, storage logical.Storage, name string, conf
return nil
}

func (b *databaseBackend) getPinnedVersion(ctx context.Context, pluginName string) (string, error) {
extendedSys, ok := b.System().(logical.ExtendedSystemView)
if !ok {
return "", fmt.Errorf("database backend does not support running as an external plugin")
}

pin, err := extendedSys.GetPinnedPluginVersion(ctx, consts.PluginTypeDatabase, pluginName)
if errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return "", nil
}
if err != nil {
return "", err
}

return pin.Version, nil
}

func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) {
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return "", nil, err
}
pluginVersionRaw, ok := data.GetOk("plugin_version")

switch {
case ok && pinnedVersion != "":
return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (v%s)", config.PluginName, pinnedVersion), nil
case pinnedVersion != "":
return pinnedVersion, nil, nil
case ok:
config.PluginVersion = pluginVersionRaw.(string)
}

var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}

// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()

if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}

config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case op == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return "", nil, err
}

var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}

if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})

config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}

return config.PluginVersion, nil, nil
}

const pathConfigConnectionHelpSyn = `
Configure connection details to a database plugin.
`
Expand Down
10 changes: 10 additions & 0 deletions sdk/helper/pluginutil/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package pluginutil

import (
"context"
"errors"
"strings"
"time"

Expand All @@ -17,6 +18,9 @@ import (
"google.golang.org/grpc"
)

// ErrPluginNotFound is returned when a plugin does not have a pinned version.
var ErrPinnedVersionNotFound = errors.New("pinned version not found")

// Looker defines the plugin Lookup function that looks into the plugin catalog
// for available plugins and returns a PluginRunner
type Looker interface {
Expand Down Expand Up @@ -144,6 +148,12 @@ type VersionedPlugin struct {
SemanticVersion *version.Version `json:"-"`
}

type PinnedVersion struct {
Name string `json:"name"`
Type consts.PluginType `json:"type"`
Version string `json:"version"`
}

// CtxCancelIfCanceled takes a context cancel func and a context. If the context is
// shutdown the cancelfunc is called. This is useful for merging two cancel
// functions.
Expand Down
3 changes: 3 additions & 0 deletions sdk/logical/system_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ type ExtendedSystemView interface {
// APILockShouldBlockRequest returns whether a namespace for the requested
// mount is locked and should be blocked
APILockShouldBlockRequest() (bool, error)

// GetPinnedPluginVersion returns the pinned version for the given plugin, if any.
GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error)
}

type PasswordGenerator func() (password string, err error)
Expand Down
75 changes: 36 additions & 39 deletions vault/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
var backend logical.Backend
// Create the new backend
sysView := c.mountEntrySysView(entry)
backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
return err
}
Expand All @@ -188,14 +188,6 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
if backendType != logical.TypeCredential {
return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType)
}
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
}
addPathCheckers(c, entry, backend, viewPath)

// If the mount is filtered or we are on a DR secondary we don't want to
Expand Down Expand Up @@ -249,7 +241,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
}

if c.logger.IsInfo() {
c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.Version)
c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.RunningVersion)
}
return nil
}
Expand Down Expand Up @@ -805,29 +797,24 @@ func (c *Core) setupCredentials(ctx context.Context) error {
// Initialize the backend
sysView := c.mountEntrySysView(entry)

backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err)

if c.isMountable(ctx, entry, consts.PluginTypeCredential) {
mountable, checkErr := c.isMountable(ctx, entry, consts.PluginTypeSecrets)
if checkErr != nil {
return errors.Join(errLoadMountsFailed, checkErr, err)
}
if mountable {
c.logger.Warn("skipping plugin-based auth entry", "path", entry.Path)
goto ROUTER_MOUNT
}
return errLoadAuthFailed
return errors.Join(errLoadAuthFailed, err)
}
if backend == nil {
return fmt.Errorf("nil backend returned from %q factory", entry.Type)
}

// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
}

// Do not start up deprecated builtin plugins. If this is a major
// upgrade, stop unsealing and shutdown. If we've already mounted this
// plugin, skip backend initialization and mount the data for posterity.
Expand Down Expand Up @@ -952,34 +939,37 @@ func (c *Core) teardownCredentials(ctx context.Context) error {
}

// newCredentialBackend is used to create and configure a new credential backend by name.
// It also returns the SHA256 of the plugin, if available.
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
t := entry.Type
if alias, ok := credentialAliases[t]; ok {
t = alias
}

pluginVersion, err := c.resolveMountEntryVersion(ctx, consts.PluginTypeCredential, entry)
if err != nil {
return nil, err
}
var runningSha string
f, ok := c.credentialBackends[t]
factory, ok := c.credentialBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version)
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, pluginVersion)
if err != nil {
return nil, "", err
return nil, err
}
if plug == nil {
errContext := t
if entry.Version != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version)
if pluginVersion != "" {
errContext += fmt.Sprintf(", version=%s", pluginVersion)
}
return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext)
return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext)
}
if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256)
}

f = plugin.Factory
factory = plugin.Factory
if !plug.Builtin {
f = wrapFactoryCheckPerms(c, plugin.Factory)
factory = wrapFactoryCheckPerms(c, plugin.Factory)
}
}
// Set up conf to pass in plugin_name
Expand All @@ -996,7 +986,7 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
}

conf["plugin_type"] = consts.PluginTypeCredential.String()
conf["plugin_version"] = entry.Version
conf["plugin_version"] = pluginVersion

authLogger := c.baseLogger.Named(fmt.Sprintf("auth.%s.%s", t, entry.Accessor))
c.AddLogger(authLogger)
Expand All @@ -1005,11 +995,11 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
MountAccessor: entry.Accessor,
MountPath: entry.Path,
Plugin: entry.Type,
PluginVersion: entry.RunningVersion,
Version: entry.Version,
PluginVersion: pluginVersion,
Version: entry.Options["version"],
})
if err != nil {
return nil, "", err
return nil, err
}

config := &logical.BackendConfig{
Expand All @@ -1021,12 +1011,19 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
EventsSender: pluginEventSender,
}

b, err := f(ctx, config)
backend, err := factory(ctx, config)
if err != nil {
return nil, "", err
return nil, err
}
if backend != nil {
entry.RunningVersion = pluginVersion
entry.RunningSha256 = runningSha
if entry.RunningVersion == "" && entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
}

return b, runningSha, nil
return backend, nil
}

func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory {
Expand Down
Loading

0 comments on commit af27ab3

Please sign in to comment.