From d8dca0595058b992fc0babebf4dcd8602b6575fb Mon Sep 17 00:00:00 2001 From: hc-github-team-secure-vault-core <82990506+hc-github-team-secure-vault-core@users.noreply.github.com> Date: Fri, 24 May 2024 13:27:37 -0700 Subject: [PATCH] Backport of Improve IdentityStore Invalidate performance into release/1.16.x (#27230) * Improve IdentityStore Invalidate performance (#27184) * improve identitystore invalidate performance * add changelog * adding test to cover invalidation of entity bucket keys within IdentityStore * minor clean ups * adding tests * add missing godoc for tests * fix incorrect merge resolution --------- Co-authored-by: Marc Boudreau --- changelog/27184.txt | 3 + vault/identity_store.go | 611 +++++++++++++++++++++-------------- vault/identity_store_test.go | 215 ++++++++++++ vault/identity_store_util.go | 30 ++ 4 files changed, 623 insertions(+), 236 deletions(-) create mode 100644 changelog/27184.txt diff --git a/changelog/27184.txt b/changelog/27184.txt new file mode 100644 index 000000000000..500045efb5af --- /dev/null +++ b/changelog/27184.txt @@ -0,0 +1,3 @@ +```release-note:change +core/identity: improve performance for secondary nodes receiving identity related updates through replication +``` diff --git a/vault/identity_store.go b/vault/identity_store.go index c10edf7ad368..8d53f4c35682 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -6,6 +6,7 @@ package vault import ( "context" "fmt" + "reflect" "strings" "time" @@ -24,6 +25,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" "github.com/patrickmn/go-cache" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -621,316 +623,453 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { defer i.lock.Unlock() switch { - // Check if the key is a storage entry key for an entity bucket case strings.HasPrefix(key, storagepacker.StoragePackerBucketsPrefix): - // Create a MemDB transaction - txn := i.db.Txn(true) - defer txn.Abort() - - // Each entity object in MemDB holds the MD5 hash of the storage - // entry key of the entity bucket. Fetch all the entities that - // belong to this bucket using the hash value. Remove these entities - // from MemDB along with all the aliases of each entity. - entitiesFetched, err := i.MemDBEntitiesByBucketKeyInTxn(txn, key) - if err != nil { - i.logger.Error("failed to fetch entities using the bucket key", "key", key) - return - } + // key is for a entity bucket in storage. + i.invalidateEntityBucket(ctx, key) + case strings.HasPrefix(key, groupBucketsPrefix): + // key is for a group bucket in storage. + i.invalidateGroupBucket(ctx, key) + case strings.HasPrefix(key, oidcTokensPrefix): + // key is for oidc tokens in storage. + i.invalidateOIDCToken(ctx) + case strings.HasPrefix(key, clientPath): + // key is for a client in storage. + i.invalidateClientPath(ctx, key) + case strings.HasPrefix(key, localAliasesBucketsPrefix): + // key is for a local alias bucket in storage. + i.invalidateLocalAliasesBucket(ctx, key) + } +} + +func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string) { + txn := i.db.Txn(true) + defer txn.Abort() - for _, entity := range entitiesFetched { - // Delete all the aliases in the entity. This function will also remove - // the corresponding alias indexes too. - err = i.deleteAliasesInEntityInTxn(txn, entity, entity.Aliases) + // The handling of entities has the added quirk of dealing with a temporary + // copy of the entity written in storage on the active node of performance + // secondary clusters. These temporary entity entries in storage must be + // removed once the actual entity appears in the storage bucket (as + // replicated from the primary cluster). + // + // This function retrieves all entities from MemDB that have a corresponding + // storage key that matches the provided key to invalidate. This is the set + // of entities that need to be updated, removed, or left alone in MemDB. + // + // The logic iterates over every entity stored in the invalidated storage + // bucket. For each entity read from the storage bucket, the set of entities + // read from MemDB is searched for the same entity. If it can't be found, + // it means that it needs to be inserted into MemDB. On the other hand, if + // the entity is found, it the storage bucket entity is compared to the + // MemDB entity. If they do not match, then the storage entity state needs + // to be used to update the MemDB entity; if they did match, then it means + // that the MemDB entity can be left alone. As each MemDB entity is + // processed in the loop, it is removed from the set of MemDB entities. + // + // Once all entities from the storage bucket have been compared to those + // retrieved from MemDB, the remaining entities from the set retrieved from + // MemDB are those that have been deleted from storage and must be removed + // from MemDB (because as MemDB entities that matches a storage bucket + // entity were processed, they were removed from the set). + memDBEntities, err := i.MemDBEntitiesByBucketKeyInTxn(txn, key) + if err != nil { + i.logger.Error("failed to fetch entities using the bucket key", "key", key) + return + } + + bucket, err := i.entityPacker.GetBucket(ctx, key) + if err != nil { + i.logger.Error("failed to refresh entities", "key", key, "error", err) + return + } + + if bucket != nil { + // The storage entry for the entity bucket exists, so we need to compare + // the entities in that bucket with those in MemDB and only update those + // that are different. The entities in the bucket storage entry are the + // source of truth. + + // Iterate over each entity item from the bucket + for _, item := range bucket.Items { + bucketEntity, err := i.parseEntityFromBucketItem(ctx, item) if err != nil { - i.logger.Error("failed to delete aliases in entity", "entity_id", entity.ID, "error", err) + i.logger.Error("failed to parse entity from bucket entry item", "error", err) return } - // Delete the entity using the same transaction - err = i.MemDBDeleteEntityByIDInTxn(txn, entity.ID) + localAliases, err := i.parseLocalAliases(bucketEntity.ID) if err != nil { - i.logger.Error("failed to delete entity from MemDB", "entity_id", entity.ID, "error", err) + i.logger.Error("failed to load local aliases from storage", "error", err) return } - } - // Get the storage bucket entry - bucket, err := i.entityPacker.GetBucket(ctx, key) - if err != nil { - i.logger.Error("failed to refresh entities", "key", key, "error", err) - return - } - - // If the underlying entry is nil, it means that this invalidation - // notification is for the deletion of the underlying storage entry. At - // this point, since all the entities belonging to this bucket are - // already removed, there is nothing else to be done. But, if the - // storage entry is non-nil, its an indication of an update. In this - // case, entities in the updated bucket needs to be reinserted into - // MemDB. - var entityIDs []string - if bucket != nil { - entityIDs = make([]string, 0, len(bucket.Items)) - for _, item := range bucket.Items { - entity, err := i.parseEntityFromBucketItem(ctx, item) - if err != nil { - i.logger.Error("failed to parse entity from bucket entry item", "error", err) - return + if localAliases != nil { + for _, alias := range localAliases.Aliases { + bucketEntity.UpsertAlias(alias) } + } - localAliases, err := i.parseLocalAliases(entity.ID) - if err != nil { - i.logger.Error("failed to load local aliases from storage", "error", err) - return - } - if localAliases != nil { - for _, alias := range localAliases.Aliases { - entity.UpsertAlias(alias) - } + var memDBEntity *identity.Entity + for i, entity := range memDBEntities { + if entity.ID == bucketEntity.ID { + memDBEntity = entity + + // Remove this processed entity from the slice, so that + // all tht will be left are unprocessed entities. + copy(memDBEntities[i:], memDBEntities[i+1:]) + memDBEntities = memDBEntities[:len(memDBEntities)-1] + break } + } + + // If the entity is not in MemDB or if it is but differs from the + // state that's in the bucket storage entry, upsert it into MemDB. - // Only update MemDB and don't touch the storage - err = i.upsertEntityInTxn(ctx, txn, entity, nil, false) + // We've considered the use of github.com/google/go-cmp here, + // but opted for sticking with reflect.DeepEqual because go-cmp + // is intended for testing and is able to panic in some + // situations. + if memDBEntity == nil || !reflect.DeepEqual(memDBEntity, bucketEntity) { + // The entity is not in MemDB, it's a new entity. Add it to MemDB. + err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false) if err != nil { - i.logger.Error("failed to update entity in MemDB", "error", err) + i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err) return } - // If we are a secondary, the entity created by the secondary - // via the CreateEntity RPC would have been cached. Now that the - // invalidation of the same has hit, there is no need of the - // cache. Clearing the cache. Writing to storage can't be - // performed by perf standbys. So only doing this in the active - // node of the secondary. + // If this is a performance secondary, the entity created on + // this node would have been cached in a local cache based on + // the result of the CreateEntity RPC call to the primary + // cluster. Since this invalidation is signaling that the + // entity is now in the primary cluster's storage, the locally + // cached entry can be removed. if i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) && i.localNode.HAState() == consts.Active { - if err := i.localAliasPacker.DeleteItem(ctx, entity.ID+tmpSuffix); err != nil { - i.logger.Error("failed to clear local alias entity cache", "error", err, "entity_id", entity.ID) + if err := i.localAliasPacker.DeleteItem(ctx, bucketEntity.ID+tmpSuffix); err != nil { + i.logger.Error("failed to clear local alias entity cache", "error", err, "entity_id", bucketEntity.ID) return } } - - entityIDs = append(entityIDs, entity.ID) } } + } + + // Any entities that are still in the memDBEntities slice are ones that do + // not exist in the bucket storage entry. These entities have to be removed + // from MemDB. + for _, memDBEntity := range memDBEntities { + err = i.deleteAliasesInEntityInTxn(txn, memDBEntity, memDBEntity.Aliases) + if err != nil { + i.logger.Error("failed to delete aliases in entity", "entity_id", memDBEntity.ID, "error", err) + return + } + + err = i.MemDBDeleteEntityByIDInTxn(txn, memDBEntity.ID) + if err != nil { + i.logger.Error("failed to delete entity from MemDB", "entity_id", memDBEntity.ID, "error", err) + return + } - // entitiesFetched are the entities before invalidation. entityIDs - // represent entities that are valid after invalidation. Clear the - // storage entries of local aliases for those entities that are - // indicated deleted by this invalidation. + // In addition, if this is an active node of a performance secondary + // cluster, remove the local alias storage entry for this deleted entity. if i.localNode.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) && i.localNode.HAState() == consts.Active { - for _, entity := range entitiesFetched { - if !strutil.StrListContains(entityIDs, entity.ID) { - if err := i.localAliasPacker.DeleteItem(ctx, entity.ID); err != nil { - i.logger.Error("failed to clear local alias for entity", "error", err, "entity_id", entity.ID) - return - } - } + if err := i.localAliasPacker.DeleteItem(ctx, memDBEntity.ID); err != nil { + i.logger.Error("failed to clear local alias for entity", "error", err, "entity_id", memDBEntity.ID) + return } } + } - txn.Commit() - return + txn.Commit() +} - // Check if the key is a storage entry key for an group bucket - // For those entities that are deleted, clear up the local alias entries - case strings.HasPrefix(key, groupBucketsPrefix): - // Create a MemDB transaction - txn := i.db.Txn(true) - defer txn.Abort() +func (i *IdentityStore) invalidateGroupBucket(ctx context.Context, key string) { + // Create a MemDB transaction + txn := i.db.Txn(true) + defer txn.Abort() + + groupsFetched, err := i.MemDBGroupsByBucketKeyInTxn(txn, key) + if err != nil { + i.logger.Error("failed to fetch groups using the bucket key", "key", key) + return + } - groupsFetched, err := i.MemDBGroupsByBucketKeyInTxn(txn, key) + for _, group := range groupsFetched { + // Delete the group using the same transaction + err = i.MemDBDeleteGroupByIDInTxn(txn, group.ID) if err != nil { - i.logger.Error("failed to fetch groups using the bucket key", "key", key) + i.logger.Error("failed to delete group from MemDB", "group_id", group.ID, "error", err) return } - for _, group := range groupsFetched { - // Delete the group using the same transaction - err = i.MemDBDeleteGroupByIDInTxn(txn, group.ID) + if group.Alias != nil { + err := i.MemDBDeleteAliasByIDInTxn(txn, group.Alias.ID, true) if err != nil { - i.logger.Error("failed to delete group from MemDB", "group_id", group.ID, "error", err) + i.logger.Error("failed to delete group alias from MemDB", "error", err) return } + } + } + + // Get the storage bucket entry + bucket, err := i.groupPacker.GetBucket(ctx, key) + if err != nil { + i.logger.Error("failed to refresh group", "key", key, "error", err) + return + } - if group.Alias != nil { - err := i.MemDBDeleteAliasByIDInTxn(txn, group.Alias.ID, true) + if bucket != nil { + for _, item := range bucket.Items { + group, err := i.parseGroupFromBucketItem(item) + if err != nil { + i.logger.Error("failed to parse group from bucket entry item", "error", err) + return + } + + // Before updating the group, check if the group exists. If it + // does, then delete the group alias from memdb, for the + // invalidation would have sent an update. + groupFetched, err := i.MemDBGroupByIDInTxn(txn, group.ID, true) + if err != nil { + i.logger.Error("failed to fetch group from MemDB", "error", err) + return + } + + // If the group has an alias remove it from memdb + if groupFetched != nil && groupFetched.Alias != nil { + err := i.MemDBDeleteAliasByIDInTxn(txn, groupFetched.Alias.ID, true) if err != nil { - i.logger.Error("failed to delete group alias from MemDB", "error", err) + i.logger.Error("failed to delete old group alias from MemDB", "error", err) return } } - } - // Get the storage bucket entry - bucket, err := i.groupPacker.GetBucket(ctx, key) - if err != nil { - i.logger.Error("failed to refresh group", "key", key, "error", err) - return + // Only update MemDB and don't touch the storage + err = i.UpsertGroupInTxn(ctx, txn, group, false) + if err != nil { + i.logger.Error("failed to update group in MemDB", "error", err) + return + } } + } - if bucket != nil { - for _, item := range bucket.Items { - group, err := i.parseGroupFromBucketItem(item) - if err != nil { - i.logger.Error("failed to parse group from bucket entry item", "error", err) - return - } + txn.Commit() +} - // Before updating the group, check if the group exists. If it - // does, then delete the group alias from memdb, for the - // invalidation would have sent an update. - groupFetched, err := i.MemDBGroupByIDInTxn(txn, group.ID, true) - if err != nil { - i.logger.Error("failed to fetch group from MemDB", "error", err) - return - } +// invalidateOIDCToken is called by the Invalidate function to handle the +// invalidation of an OIDC token storage entry. +func (i *IdentityStore) invalidateOIDCToken(ctx context.Context) { + ns, err := namespace.FromContext(ctx) + if err != nil { + i.logger.Error("error retrieving namespace", "error", err) + return + } - // If the group has an alias remove it from memdb - if groupFetched != nil && groupFetched.Alias != nil { - err := i.MemDBDeleteAliasByIDInTxn(txn, groupFetched.Alias.ID, true) - if err != nil { - i.logger.Error("failed to delete old group alias from MemDB", "error", err) - return - } - } + // Wipe the cache for the requested namespace. This will also clear + // the shared namespace as well. + if err := i.oidcCache.Flush(ns); err != nil { + i.logger.Error("error flushing oidc cache", "error", err) + return + } +} - // Only update MemDB and don't touch the storage - err = i.UpsertGroupInTxn(ctx, txn, group, false) - if err != nil { - i.logger.Error("failed to update group in MemDB", "error", err) - return - } - } - } +// invalidateClientPath is called by the Invalidate function to handle the +// invalidation of a client path storage entry. +func (i *IdentityStore) invalidateClientPath(ctx context.Context, key string) { + name := strings.TrimPrefix(key, clientPath) - txn.Commit() + // Invalidate the cached client in memdb + if err := i.memDBDeleteClientByName(ctx, name); err != nil { + i.logger.Error("error invalidating client", "error", err, "key", key) return + } +} - case strings.HasPrefix(key, oidcTokensPrefix): - ns, err := namespace.FromContext(ctx) - if err != nil { - i.logger.Error("error retrieving namespace", "error", err) - return - } +// invalidateLocalAliasBucket is called by the Invalidate function to handle the +// invalidation of a local alias bucket storage entry. +func (i *IdentityStore) invalidateLocalAliasesBucket(ctx context.Context, key string) { + // This invalidation only happens on performance standby servers - // Wipe the cache for the requested namespace. This will also clear - // the shared namespace as well. - if err := i.oidcCache.Flush(ns); err != nil { - i.logger.Error("error flushing oidc cache", "error", err) - } - case strings.HasPrefix(key, clientPath): - name := strings.TrimPrefix(key, clientPath) + // Create a MemDB transaction and abort it once this function returns + txn := i.db.Txn(true) + defer txn.Abort() - // Invalidate the cached client in memdb - if err := i.memDBDeleteClientByName(ctx, name); err != nil { - i.logger.Error("error invalidating client", "error", err, "key", key) - return - } - case strings.HasPrefix(key, localAliasesBucketsPrefix): - // - // This invalidation only happens on perf standbys - // - - txn := i.db.Txn(true) - defer txn.Abort() - - // Find all the local aliases belonging to this bucket and remove it - // both from aliases table and entities table. We will add the local - // aliases back by parsing the storage key. This way the deletion - // invalidation gets handled. - aliases, err := i.MemDBLocalAliasesByBucketKeyInTxn(txn, key) - if err != nil { - i.logger.Error("failed to fetch entities using the bucket key", "key", key) - return - } + // Local aliases have the added complexity of being associated with + // entities. Whenever a local alias is updated or inserted into MemDB, its + // associated MemDB-stored entity must also be updated. + // + // This function retrieves all local aliases that have a corresponding + // storage key that matches the provided key to invalidate. This is the + // set of local aliases that need to be updated, removed, or left + // alone in MemDB. Each of these operations is done as its own MemDB + // operation, but the corresponding changes that need to be made to the + // associated entities can be batched together to cut down on the number of + // MemDB operations. + // + // The logic iterates over every local alias stored at the invalidated key. + // For each local alias read from the storage entry, the set of local + // aliases read from MemDB is searched for the same local alias. If it can't + // be found, it means that it needs to be inserted into MemDB. However, if + // it's found, it must be compared with the local alias from the storage. If + // they don't match, it means that the local alias in MemDB needs to be + // updated. If they did match, it means that this particular local alias did + // not change in storage, so nothing further needs to be done. Each local + // alias processed in this loop is removed from the set of retrieved local + // aliases. The local alias is also added to the map tracking local aliases + // that need to be upserted in their associated entities in MemDB. + // + // Once the code is done iterating over all of the local aliases from + // storage, any local aliases still in the set retrieved from MemDB + // corresponds to a local alias that is no longer in storage and must be + // removed from MemDB. These local aliases are added to the map tracking + // local aliases to remove from their entities in MemDB. The actual removal + // of the local aliases themselves is done as part of the tidying up of the + // associated entities, described below. + // + // In order to batch the changes to the associated entities, a map of entity + // to local aliases (slice of local alias) is built up in the loop that + // iterates over the local aliases from storage. Similarly, the code that + // detects which local aliases to remove from MemDB also builds a separate + // map of entity to local aliases (slice of local alias). Each element in + // the map of local aliases to update in their entity is processed as + // follows: the mapped slice of local aliases is iterated over and each + // local alias is upserted into the entity and then the entity itself is + // upserted. Then, each element in the map of local aliases to remove from + // their entity is processed as follows: the + + // Get all cached local aliases to compare with invalidated bucket + memDBLocalAliases, err := i.MemDBLocalAliasesByBucketKeyInTxn(txn, key) + if err != nil { + i.logger.Error("failed to fetch local aliases using the bucket key", "key", key, "error", err) + return + } - for _, alias := range aliases { - entity, err := i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, true) - if err != nil { - i.logger.Error("failed to fetch entity during local alias invalidation", "entity_id", alias.CanonicalID, "error", err) - return - } - if entity == nil { - i.logger.Error("failed to fetch entity during local alias invalidation, missing entity", "entity_id", alias.CanonicalID, "error", err) + // Get local aliases from the invalidated bucket + bucket, err := i.localAliasPacker.GetBucket(ctx, key) + if err != nil { + i.logger.Error("failed to refresh local aliases", "key", key, "error", err) + return + } + + // This map tracks the set of local aliases that need to be updated in each + // affected entity in MemDB. + entityLocalAliasesToUpsert := map[*identity.Entity][]*identity.Alias{} + + // This map tracks the set of local aliases that need to be removed from + // their affected entity in MemDB, as well as removing the local alias + // themselves. + entityLocalAliasesToRemove := map[*identity.Entity][]*identity.Alias{} + + if bucket != nil { + // The storage entry for the local alias bucket exists, so we need to + // compare the local aliases in that bucket with those in MemDB and only + // update those that are different. The local aliases in the bucket are + // the source of truth. + + // Iterate over each local alias item from the bucket + for _, item := range bucket.Items { + if strings.HasSuffix(item.ID, tmpSuffix) { continue } - // Delete local aliases from the entity. - err = i.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias}) - if err != nil { - i.logger.Error("failed to delete aliases in entity", "entity_id", entity.ID, "error", err) - return - } + var bucketLocalAliases identity.LocalAliases - // Update the entity with removed alias. - if err := i.MemDBUpsertEntityInTxn(txn, entity); err != nil { - i.logger.Error("failed to delete entity from MemDB", "entity_id", entity.ID, "error", err) + err = anypb.UnmarshalTo(item.Message, &bucketLocalAliases, proto.UnmarshalOptions{}) + if err != nil { + i.logger.Error("failed to parse local aliases during invalidation", "item_id", item.ID, "error", err) return } - } - // Now read the invalidated storage key - bucket, err := i.localAliasPacker.GetBucket(ctx, key) - if err != nil { - i.logger.Error("failed to refresh local aliases", "key", key, "error", err) - return - } - if bucket != nil { - for _, item := range bucket.Items { - if strings.HasSuffix(item.ID, tmpSuffix) { - continue - } - - var localAliases identity.LocalAliases - err = ptypes.UnmarshalAny(item.Message, &localAliases) - if err != nil { - i.logger.Error("failed to parse local aliases during invalidation", "error", err) + for _, bucketLocalAlias := range bucketLocalAliases.Aliases { + // Find the entity related to bucketLocalAlias in MemDB in order + // to track any local aliases modifications that must be made in + // this entity. + memDBEntity := i.FetchEntityForLocalAliasInTxn(txn, bucketLocalAlias) + if memDBEntity == nil { + // FetchEntityForLocalAliasInTxn already logs any error return } - for _, alias := range localAliases.Aliases { - // Add to the aliases table - if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil { - i.logger.Error("failed to insert local alias to memdb during invalidation", "error", err) - return + + // memDBLocalAlias starts off nil but gets set to the local + // alias from memDBLocalAliases whose ID matches the ID of + // bucketLocalAlias. + var memDBLocalAlias *identity.Alias + for i, localAlias := range memDBLocalAliases { + if localAlias.ID == bucketLocalAlias.ID { + memDBLocalAlias = localAlias + + // Remove this processed local alias from the + // memDBLocalAliases slice, so that all that + // will be left are unprocessed local aliases. + copy(memDBLocalAliases[i:], memDBLocalAliases[i+1:]) + memDBLocalAliases = memDBLocalAliases[:len(memDBLocalAliases)-1] + + break } + } - // Fetch the associated entity and add the alias to that too. - entity, err := i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, false) + // We've considered the use of github.com/google/go-cmp here, + // but opted for sticking with reflect.DeepEqual because go-cmp + // is intended for testing and is able to panic in some + // situations. + if memDBLocalAlias == nil || !reflect.DeepEqual(memDBLocalAlias, bucketLocalAlias) { + // The bucketLocalAlias is not in MemDB or it has changed in + // storage. + err = i.MemDBUpsertAliasInTxn(txn, bucketLocalAlias, false) if err != nil { - i.logger.Error("failed to fetch entity during local alias invalidation", "error", err) + i.logger.Error("failed to update local alias in MemDB", "alias_id", bucketLocalAlias.ID, "error", err) return } - if entity == nil { - cachedEntityItem, err := i.localAliasPacker.GetItem(alias.CanonicalID + tmpSuffix) - if err != nil { - i.logger.Error("failed to fetch cached entity", "key", key, "error", err) - return - } - if cachedEntityItem != nil { - entity, err = i.parseCachedEntity(cachedEntityItem) - if err != nil { - i.logger.Error("failed to parse cached entity", "key", key, "error", err) - return - } - } - } - if entity == nil { - i.logger.Error("received local alias invalidation for an invalid entity", "item.ID", item.ID) - return - } - entity.UpsertAlias(alias) - // Update the entities table - if err := i.MemDBUpsertEntityInTxn(txn, entity); err != nil { - i.logger.Error("failed to upsert entity during local alias invalidation", "error", err) - return - } + // Add this local alias to the set of local aliases that + // need to be updated for memDBEntity. + entityLocalAliasesToUpsert[memDBEntity] = append(entityLocalAliasesToUpsert[memDBEntity], bucketLocalAlias) } } } - txn.Commit() - return } + + // Any local aliases still remaining in memDBLocalAliases do not exist in + // storage and should be removed from MemDB. + for _, memDBLocalAlias := range memDBLocalAliases { + memDBEntity := i.FetchEntityForLocalAliasInTxn(txn, memDBLocalAlias) + if memDBEntity == nil { + // FetchEntityForLocalAliasInTxn already logs any error + return + } + + entityLocalAliasesToRemove[memDBEntity] = append(entityLocalAliasesToRemove[memDBEntity], memDBLocalAlias) + } + + // Now process the entityLocalAliasesToUpsert map. + for entity, localAliases := range entityLocalAliasesToUpsert { + for _, localAlias := range localAliases { + entity.UpsertAlias(localAlias) + } + + err = i.MemDBUpsertEntityInTxn(txn, entity) + if err != nil { + i.logger.Error("failed to update entity in MemDB", "entity_id", entity.ID, "error", err) + return + } + } + + // Finally process the entityLocalAliasesToRemove map. + for entity, localAliases := range entityLocalAliasesToRemove { + // The deleteAliasesInEntityInTxn removes the provided aliases from + // the entity, but it also removes the aliases themselves from MemDB. + err := i.deleteAliasesInEntityInTxn(txn, entity, localAliases) + if err != nil { + i.logger.Error("failed to delete aliases in entity", "entity_id", entity.ID, "error", err) + return + } + + err = i.MemDBUpsertEntityInTxn(txn, entity) + if err != nil { + i.logger.Error("failed to update entity in MemDB", "entity_id", entity.ID, "error", err) + return + } + } + + txn.Commit() } func (i *IdentityStore) parseLocalAliases(entityID string) (*identity.LocalAliases, error) { diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index 9ed4659b8d27..7c826dfa0c33 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/storagepacker" "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/anypb" ) @@ -912,3 +913,217 @@ func TestIdentityStore_DeleteCaseSensitivityKey(t *testing.T) { t.Fatalf("bad: expected no entry for casesensitivity key") } } + +// TestIdentityStoreInvalidate_Entities verifies the proper handling of +// entities in the Invalidate method. +func TestIdentityStoreInvalidate_Entities(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // Create an entity in storage then call the Invalidate function + // + id, err := uuid.GenerateUUID() + require.NoError(t, err) + + entity := &identity.Entity{ + Name: "test", + NamespaceID: namespace.RootNamespaceID, + ID: id, + Aliases: []*identity.Alias{}, + BucketKey: c.identityStore.entityPacker.BucketKey(id), + } + + p := c.identityStore.entityPacker + + // Persist the entity which we are merging to + entityAsAny, err := anypb.New(entity) + require.NoError(t, err) + + item := &storagepacker.Item{ + ID: id, + Message: entityAsAny, + } + + err = p.PutItem(context.Background(), item) + require.NoError(t, err) + + c.identityStore.Invalidate(context.Background(), p.BucketKey(id)) + + txn := c.identityStore.db.Txn(true) + + memEntity, err := c.identityStore.MemDBEntityByIDInTxn(txn, id, true) + assert.NoError(t, err) + assert.NotNil(t, memEntity) + + txn.Commit() + + // Modify the entity in storage then call the Invalidate function + entity.Metadata = make(map[string]string) + entity.Metadata["foo"] = "bar" + + entityAsAny, err = anypb.New(entity) + require.NoError(t, err) + + item.Message = entityAsAny + + p.PutItem(context.Background(), item) + + c.identityStore.Invalidate(context.Background(), p.BucketKey(id)) + + txn = c.identityStore.db.Txn(true) + + memEntity, err = c.identityStore.MemDBEntityByIDInTxn(txn, id, true) + assert.NoError(t, err) + assert.Contains(t, memEntity.Metadata, "foo") + + txn.Commit() + + // Delete the entity in storage then call the Invalidate function + err = p.DeleteItem(context.Background(), id) + require.NoError(t, err) + + c.identityStore.Invalidate(context.Background(), p.BucketKey(id)) + + txn = c.identityStore.db.Txn(true) + + memEntity, err = c.identityStore.MemDBEntityByIDInTxn(txn, id, true) + assert.NoError(t, err) + assert.Nil(t, memEntity) + + txn.Commit() +} + +// TestIdentityStoreInvalidate_LocalAliasesWithEntity verifies the correct +// handling of local aliases in the Invalidate method. +func TestIdentityStoreInvalidate_LocalAliasesWithEntity(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // Create an entity in storage then call the Invalidate function + // + entityID, err := uuid.GenerateUUID() + require.NoError(t, err) + + entity := &identity.Entity{ + Name: "test", + NamespaceID: namespace.RootNamespaceID, + ID: entityID, + Aliases: []*identity.Alias{}, + BucketKey: c.identityStore.entityPacker.BucketKey(entityID), + } + + aliasID, err := uuid.GenerateUUID() + require.NoError(t, err) + + localAliases := &identity.LocalAliases{ + Aliases: []*identity.Alias{ + { + ID: aliasID, + Name: "test", + NamespaceID: namespace.RootNamespaceID, + CanonicalID: entityID, + MountAccessor: "userpass-000000", + }, + }, + } + + ep := c.identityStore.entityPacker + + // Persist the entity which we are merging to + entityAsAny, err := anypb.New(entity) + require.NoError(t, err) + + entityItem := &storagepacker.Item{ + ID: entityID, + Message: entityAsAny, + } + + err = ep.PutItem(context.Background(), entityItem) + require.NoError(t, err) + + c.identityStore.Invalidate(context.Background(), ep.BucketKey(entityID)) + + lap := c.identityStore.localAliasPacker + + localAliasesAsAny, err := anypb.New(localAliases) + require.NoError(t, err) + + localAliasesItem := &storagepacker.Item{ + ID: entityID, + Message: localAliasesAsAny, + } + + err = lap.PutItem(context.Background(), localAliasesItem) + require.NoError(t, err) + + c.identityStore.Invalidate(context.Background(), lap.BucketKey(entityID)) + + txn := c.identityStore.db.Txn(true) + + memDBEntity, err := c.identityStore.MemDBEntityByIDInTxn(txn, entityID, true) + assert.NoError(t, err) + assert.NotNil(t, memDBEntity) + + memDBLocalAlias, err := c.identityStore.MemDBAliasByIDInTxn(txn, aliasID, true, false) + assert.NoError(t, err) + assert.NotNil(t, memDBLocalAlias) + assert.Equal(t, 1, len(memDBEntity.Aliases)) + assert.NotNil(t, memDBEntity.Aliases[0]) + assert.Equal(t, memDBEntity.Aliases[0].ID, memDBLocalAlias.ID) + + txn.Commit() +} + +// TestIdentityStoreInvalidate_TemporaryEntity verifies the proper handling of +// temporary entities in the Invalidate method. +func TestIdentityStoreInvalidate_TemporaryEntity(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // Create an entity in storage then call the Invalidate function + // + entityID, err := uuid.GenerateUUID() + require.NoError(t, err) + + tempEntity := &identity.Entity{ + Name: "test", + NamespaceID: namespace.RootNamespaceID, + ID: entityID, + Aliases: []*identity.Alias{}, + BucketKey: c.identityStore.entityPacker.BucketKey(entityID), + } + + lap := c.identityStore.localAliasPacker + ep := c.identityStore.entityPacker + + // Persist the entity which we are merging to + tempEntityAsAny, err := anypb.New(tempEntity) + require.NoError(t, err) + + tempEntityItem := &storagepacker.Item{ + ID: entityID + tmpSuffix, + Message: tempEntityAsAny, + } + + err = lap.PutItem(context.Background(), tempEntityItem) + require.NoError(t, err) + + entityAsAny := tempEntityAsAny + + entityItem := &storagepacker.Item{ + ID: entityID, + Message: entityAsAny, + } + + err = ep.PutItem(context.Background(), entityItem) + require.NoError(t, err) + + c.identityStore.Invalidate(context.Background(), ep.BucketKey(entityID)) + + txn := c.identityStore.db.Txn(true) + + memDBEntity, err := c.identityStore.MemDBEntityByIDInTxn(txn, entityID, true) + assert.NoError(t, err) + assert.NotNil(t, memDBEntity) + + item, err := lap.GetItem(lap.BucketKey(entityID) + tmpSuffix) + assert.NoError(t, err) + assert.Nil(t, item) +} diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 6d9190cbe293..c78db0bc70f7 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -1269,6 +1269,36 @@ func (i *IdentityStore) MemDBDeleteEntityByID(entityID string) error { return nil } +// FetchEntityForLocalAliasInTxn fetches the entity associated with the provided +// local identity.Alias. MemDB will first be searched for the entity. If it is +// not found there, the localAliasPacker storagepacker.StoragePacker will be +// used. If an error occurs, an appropriate error message is logged and nil is +// returned. +func (i *IdentityStore) FetchEntityForLocalAliasInTxn(txn *memdb.Txn, alias *identity.Alias) *identity.Entity { + entity, err := i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, false) + if err != nil { + i.logger.Error("failed to fetch entity from local alias", "entity_id", alias.CanonicalID, "error", err) + return nil + } + + if entity == nil { + cachedEntityItem, err := i.localAliasPacker.GetItem(alias.CanonicalID + tmpSuffix) + if err != nil { + i.logger.Error("failed to fetch cached entity from local alias", "key", alias.CanonicalID+tmpSuffix, "error", err) + return nil + } + if cachedEntityItem != nil { + entity, err = i.parseCachedEntity(cachedEntityItem) + if err != nil { + i.logger.Error("failed to parse cached entity", "key", alias.CanonicalID+tmpSuffix, "error", err) + return nil + } + } + } + + return entity +} + func (i *IdentityStore) MemDBDeleteEntityByIDInTxn(txn *memdb.Txn, entityID string) error { if entityID == "" { return nil