diff --git a/controller/handler_edge_ctrl/common_tunnel.go b/controller/handler_edge_ctrl/common_tunnel.go index 22102e298..34c266850 100644 --- a/controller/handler_edge_ctrl/common_tunnel.go +++ b/controller/handler_edge_ctrl/common_tunnel.go @@ -12,6 +12,7 @@ import ( "github.com/openziti/foundation/v2/concurrenz" "github.com/openziti/storage/boltz" "github.com/sirupsen/logrus" + "go.etcd.io/bbolt" "sync" "time" ) @@ -131,6 +132,18 @@ func (self *baseTunnelRequestContext) ensureApiSessionLocking(configTypes []stri } } + identityMgr := self.handler.getAppEnv().Managers.Identity + if cachedApiSessionId, _ := identityMgr.GetAnnotation(self.identity.Id, "apiSessionId"); cachedApiSessionId != nil { + apiSession, _ := self.handler.getAppEnv().Managers.ApiSession.Read(*cachedApiSessionId) + if apiSession != nil && apiSession.IdentityId == self.identity.Id { + self.apiSession = apiSession + if _, err := self.handler.getAppEnv().GetManagers().ApiSession.MarkActivityByTokens(self.apiSession.Token); err != nil { + logger.WithError(err).Error("unexpected error while marking api session activity") + } + return true + } + } + apiSession := &model.ApiSession{ Token: uuid.NewString(), IdentityId: self.identity.Id, @@ -139,14 +152,23 @@ func (self *baseTunnelRequestContext) ensureApiSessionLocking(configTypes []stri IPAddress: self.handler.getChannel().Underlay().GetRemoteAddr().String(), } - var err error - apiSession.Id, err = self.handler.getAppEnv().GetManagers().ApiSession.Create(apiSession, nil) - if err != nil { - self.err = internalError(err) - return false - } + err := self.handler.getAppEnv().GetDbProvider().GetDb().Update(func(tx *bbolt.Tx) error { + ctx := boltz.NewMutateContext(tx) + + var err error + apiSession.Id, err = self.handler.getAppEnv().GetManagers().ApiSession.Create(ctx, apiSession, nil) + if err != nil { + return err + } + + if err = identityMgr.Annotate(ctx, self.identity.Id, "apiSessionId", apiSession.Id); err != nil { + logger.WithError(err).Error("failed to cache new api session on router identity") + } + + apiSession, err = self.handler.getAppEnv().GetManagers().ApiSession.ReadInTx(ctx.Tx(), apiSession.Id) + return err + }) - apiSession, err = self.handler.getAppEnv().GetManagers().ApiSession.Read(apiSession.Id) if err != nil { self.err = internalError(err) return false diff --git a/controller/internal/routes/authenticate_router.go b/controller/internal/routes/authenticate_router.go index c684b8b2e..409ff21f5 100644 --- a/controller/internal/routes/authenticate_router.go +++ b/controller/internal/routes/authenticate_router.go @@ -170,7 +170,7 @@ func (ro *AuthRouter) authHandler(ae *env.AppEnv, rc *response.RequestContext, h sessionCerts = append(sessionCerts, sessionCert) } - sessionId, err := ae.Managers.ApiSession.Create(newApiSession, sessionCerts) + sessionId, err := ae.Managers.ApiSession.Create(nil, newApiSession, sessionCerts) if err != nil { rc.RespondWithError(err) diff --git a/controller/model/api_session_manager.go b/controller/model/api_session_manager.go index 887227436..d96135804 100644 --- a/controller/model/api_session_manager.go +++ b/controller/model/api_session_manager.go @@ -50,34 +50,41 @@ func (self *ApiSessionManager) newModelEntity() edgeEntity { return &ApiSession{} } -func (self *ApiSessionManager) Create(entity *ApiSession, sessionCerts []*ApiSessionCertificate) (string, error) { - entity.Id = cuid.New() //use cuids which are longer than shortids but are monotonic - - var apiSessionId string - err := self.env.GetDbProvider().GetDb().Update(func(tx *bbolt.Tx) error { - var err error - ctx := boltz.NewMutateContext(tx) - apiSessionId, err = self.createEntityInTx(ctx, entity) - - if err != nil { +func (self *ApiSessionManager) Create(ctx boltz.MutateContext, entity *ApiSession, sessionCerts []*ApiSessionCertificate) (string, error) { + if ctx == nil { + var apiSessionId string + err := self.env.GetDbProvider().GetDb().Update(func(tx *bbolt.Tx) error { + ctx = boltz.NewMutateContext(tx) + var err error + apiSessionId, err = self.CreateInCtx(ctx, entity, sessionCerts) return err + }) + if err != nil { + return "", err } + return apiSessionId, nil + } - for _, sessionCert := range sessionCerts { - sessionCert.ApiSessionId = apiSessionId - _, err := self.env.GetManagers().ApiSessionCertificate.createEntityInTx(ctx, sessionCert) + return self.CreateInCtx(ctx, entity, sessionCerts) +} - if err != nil { - return err - } - } - return nil - }) +func (self *ApiSessionManager) CreateInCtx(ctx boltz.MutateContext, entity *ApiSession, sessionCerts []*ApiSessionCertificate) (string, error) { + entity.Id = cuid.New() //use cuids which are longer than shortids but are monotonic + apiSessionId, err := self.createEntityInTx(ctx, entity) if err != nil { - self.MarkActivityById(apiSessionId) + return "", err } + for _, sessionCert := range sessionCerts { + sessionCert.ApiSessionId = apiSessionId + if _, err = self.env.GetManagers().ApiSessionCertificate.createEntityInTx(ctx, sessionCert); err != nil { + return "", err + } + } + + self.MarkActivityById(apiSessionId) + return apiSessionId, err } @@ -98,7 +105,7 @@ func (self *ApiSessionManager) ReadByToken(token string) (*ApiSession, error) { return modelApiSession, nil } -func (self *ApiSessionManager) readInTx(tx *bbolt.Tx, id string) (*ApiSession, error) { +func (self *ApiSessionManager) ReadInTx(tx *bbolt.Tx, id string) (*ApiSession, error) { modelApiSession := &ApiSession{} if err := self.readEntityInTx(tx, id, modelApiSession); err != nil { return nil, err @@ -202,7 +209,7 @@ func (self *ApiSessionManager) Stream(query string, collect func(*ApiSession, er for cursor := self.Store.IterateIds(tx, filter); cursor.IsValid(); cursor.Next() { current := cursor.Current() - apiSession, err := self.readInTx(tx, string(current)) + apiSession, err := self.ReadInTx(tx, string(current)) if err := collect(apiSession, err); err != nil { return err } @@ -240,7 +247,7 @@ func (self *ApiSessionManager) Query(query string) (*ApiSessionListResult, error func (self *ApiSessionManager) VisitFingerprintsForApiSessionId(apiSessionId string, visitor func(fingerprint string) bool) error { return self.GetDb().View(func(tx *bbolt.Tx) error { - apiSession, err := self.readInTx(tx, apiSessionId) + apiSession, err := self.ReadInTx(tx, apiSessionId) if err != nil { return errors.Wrapf(err, "could not query fingerprints by api session id [%s]", apiSessionId) } @@ -288,7 +295,7 @@ type ApiSessionListResult struct { func (result *ApiSessionListResult) collect(tx *bbolt.Tx, ids []string, queryMetaData *models.QueryMetaData) error { result.QueryMetaData = *queryMetaData for _, key := range ids { - ApiSession, err := result.manager.readInTx(tx, key) + ApiSession, err := result.manager.ReadInTx(tx, key) if err != nil { return err } diff --git a/controller/model/base_manager.go b/controller/model/base_manager.go index 897300f65..9db55c897 100644 --- a/controller/model/base_manager.go +++ b/controller/model/base_manager.go @@ -30,6 +30,8 @@ import ( "reflect" ) +const annotationsBucketName = "annotations" + type EntityManager interface { models.EntityRetriever[models.Entity] command.EntityDeleter @@ -381,6 +383,31 @@ func (self *baseEntityManager) iterateRelatedEntitiesInTx(tx *bbolt.Tx, id, fiel return nil } +func (self *baseEntityManager) Annotate(ctx boltz.MutateContext, entityId string, key, value string) error { + entityBucket := self.GetStore().GetEntityBucket(ctx.Tx(), []byte(entityId)) + if entityBucket == nil { + return boltz.NewNotFoundError(self.GetStore().GetEntityType(), "id", entityId) + } + annotationsBucket := entityBucket.GetOrCreatePath(annotationsBucketName) + annotationsBucket.SetString(key, value, nil) + return annotationsBucket.GetError() +} + +func (self *baseEntityManager) GetAnnotation(entityId string, key string) (*string, error) { + var result *string + err := self.GetDb().View(func(tx *bbolt.Tx) error { + entityBucket := self.GetStore().GetEntityBucket(tx, []byte(entityId)) + if entityBucket == nil { + return nil + } + if annotationsBucket := entityBucket.GetPath(annotationsBucketName); annotationsBucket != nil { + result = annotationsBucket.GetString(key) + } + return nil + }) + return result, err +} + type AndFieldChecker struct { first boltz.FieldChecker second boltz.FieldChecker diff --git a/controller/model/testing.go b/controller/model/testing.go index 1a7c11553..f5115e3f2 100644 --- a/controller/model/testing.go +++ b/controller/model/testing.go @@ -200,7 +200,7 @@ func (ctx *TestContext) requireNewApiSession(identity *Identity) *ApiSession { Identity: identity, LastActivityAt: time.Now(), } - _, err := ctx.managers.ApiSession.Create(entity, nil) + _, err := ctx.managers.ApiSession.Create(nil, entity, nil) ctx.NoError(err) return entity }