From 3907f1416c1085b2e51665ee3ec30cc8690fac7c Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Sat, 26 Jun 2021 12:42:58 -0700 Subject: [PATCH] Simplify MultiStorage methods --- storage/allmulti/allmulti.go | 45 ++++++++++++++++----------------- storage/allmulti/bstoken.go | 24 ++++++++---------- storage/allmulti/certauth.go | 48 ++++++++++++++++-------------------- storage/allmulti/push.go | 11 ++++----- storage/allmulti/pushcert.go | 36 ++++++++++++--------------- storage/allmulti/queue.go | 47 ++++++++++++++--------------------- 6 files changed, 94 insertions(+), 117 deletions(-) diff --git a/storage/allmulti/allmulti.go b/storage/allmulti/allmulti.go index f205316..b325286 100644 --- a/storage/allmulti/allmulti.go +++ b/storage/allmulti/allmulti.go @@ -22,35 +22,36 @@ func New(logger log.Logger, stores ...storage.AllStorage) *MultiAllStorage { return &MultiAllStorage{logger: logger, stores: stores} } -func (ms *MultiAllStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error { - finalErr := ms.stores[0].StoreAuthenticate(r, msg) +type storageErrorer func(storage.AllStorage) error + +func (ms *MultiAllStorage) runAndLogOthers(storageCallback storageErrorer) { for n, storage := range ms.stores[1:] { - if err := storage.StoreAuthenticate(r, msg); err != nil { - ms.logger.Info("method", "StoreAuthenticate", "storage", n+1, "err", err) - continue + if err := storageCallback(storage); err != nil { + ms.logger.Info("msg", n+1, "err", err) } } - return finalErr +} + +func (ms *MultiAllStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error { + err := ms.stores[0].StoreAuthenticate(r, msg) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.StoreAuthenticate(r, msg) + }) + return err } func (ms *MultiAllStorage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { - finalErr := ms.stores[0].StoreTokenUpdate(r, msg) - for n, storage := range ms.stores[1:] { - if err := storage.StoreTokenUpdate(r, msg); err != nil { - ms.logger.Info("method", "StoreTokenUpdate", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].StoreTokenUpdate(r, msg) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.StoreTokenUpdate(r, msg) + }) + return err } func (ms *MultiAllStorage) Disable(r *mdm.Request) error { - finalErr := ms.stores[0].Disable(r) - for n, storage := range ms.stores[1:] { - if err := storage.Disable(r); err != nil { - ms.logger.Info("method", "Disable", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].Disable(r) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.Disable(r) + }) + return err } diff --git a/storage/allmulti/bstoken.go b/storage/allmulti/bstoken.go index e08262a..5f94186 100644 --- a/storage/allmulti/bstoken.go +++ b/storage/allmulti/bstoken.go @@ -2,26 +2,22 @@ package allmulti import ( "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" ) func (ms *MultiAllStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { - finalErr := ms.stores[0].StoreBootstrapToken(r, msg) - for n, storage := range ms.stores[1:] { - if err := storage.StoreBootstrapToken(r, msg); err != nil { - ms.logger.Info("method", "StoreBootstrapToken", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].StoreBootstrapToken(r, msg) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.StoreBootstrapToken(r, msg) + }) + return err } func (ms *MultiAllStorage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { finalToken, finalErr := ms.stores[0].RetrieveBootstrapToken(r, msg) - for n, storage := range ms.stores[1:] { - if _, err := storage.RetrieveBootstrapToken(r, msg); err != nil { - ms.logger.Info("method", "RetrieveBootstrapToken", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.RetrieveBootstrapToken(r, msg) + return err + }) return finalToken, finalErr } diff --git a/storage/allmulti/certauth.go b/storage/allmulti/certauth.go index b57644f..5055ab9 100644 --- a/storage/allmulti/certauth.go +++ b/storage/allmulti/certauth.go @@ -1,47 +1,41 @@ package allmulti -import "github.com/micromdm/nanomdm/mdm" +import ( + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" +) func (ms *MultiAllStorage) HasCertHash(r *mdm.Request, hash string) (bool, error) { hasFinal, finalErr := ms.stores[0].HasCertHash(r, hash) - for n, storage := range ms.stores[1:] { - if _, err := storage.HasCertHash(r, hash); err != nil { - ms.logger.Info("method", "HasCertHash", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.HasCertHash(r, hash) + return err + }) return hasFinal, finalErr } func (ms *MultiAllStorage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) { hasFinal, finalErr := ms.stores[0].EnrollmentHasCertHash(r, hash) - for n, storage := range ms.stores[1:] { - if _, err := storage.EnrollmentHasCertHash(r, hash); err != nil { - ms.logger.Info("method", "EnrollmentHasCertHash", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.EnrollmentHasCertHash(r, hash) + return err + }) return hasFinal, finalErr } func (ms *MultiAllStorage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) { isAssocFinal, finalErr := ms.stores[0].IsCertHashAssociated(r, hash) - for n, storage := range ms.stores[1:] { - if _, err := storage.IsCertHashAssociated(r, hash); err != nil { - ms.logger.Info("method", "IsCertHashAssociated", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.IsCertHashAssociated(r, hash) + return err + }) return isAssocFinal, finalErr } func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error { - finalErr := ms.stores[0].AssociateCertHash(r, hash) - for n, storage := range ms.stores[1:] { - if err := storage.AssociateCertHash(r, hash); err != nil { - ms.logger.Info("method", "AssociateCertHash", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].AssociateCertHash(r, hash) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.AssociateCertHash(r, hash) + }) + return err } diff --git a/storage/allmulti/push.go b/storage/allmulti/push.go index 2db9c5d..a7ab571 100644 --- a/storage/allmulti/push.go +++ b/storage/allmulti/push.go @@ -4,15 +4,14 @@ import ( "context" "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" ) func (ms *MultiAllStorage) RetrievePushInfo(ctx context.Context, ids []string) (map[string]*mdm.Push, error) { finalMap, finalErr := ms.stores[0].RetrievePushInfo(ctx, ids) - for n, storage := range ms.stores[1:] { - if _, err := storage.RetrievePushInfo(ctx, ids); err != nil { - ms.logger.Info("method", "RetrievePushInfo", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.RetrievePushInfo(ctx, ids) + return err + }) return finalMap, finalErr } diff --git a/storage/allmulti/pushcert.go b/storage/allmulti/pushcert.go index 0d89758..5fc311f 100644 --- a/storage/allmulti/pushcert.go +++ b/storage/allmulti/pushcert.go @@ -3,37 +3,33 @@ package allmulti import ( "context" "crypto/tls" + + "github.com/micromdm/nanomdm/storage" ) func (ms *MultiAllStorage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) { finalStale, finalErr := ms.stores[0].IsPushCertStale(ctx, topic, staleToken) - for n, storage := range ms.stores[1:] { - if _, err := storage.IsPushCertStale(ctx, topic, staleToken); err != nil { - ms.logger.Info("method", "IsPushCertStale", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.IsPushCertStale(ctx, topic, staleToken) + return err + }) return finalStale, finalErr } func (ms *MultiAllStorage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) { finalCert, finalToken, finalErr := ms.stores[0].RetrievePushCert(ctx, topic) - for n, storage := range ms.stores[1:] { - if _, _, err := storage.RetrievePushCert(ctx, topic); err != nil { - ms.logger.Info("method", "RetrievePushCert", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, _, err := s.RetrievePushCert(ctx, topic) + return err + }) + return finalCert, finalToken, finalErr } func (ms *MultiAllStorage) StorePushCert(ctx context.Context, pemCert, pemKey []byte) error { - finalErr := ms.stores[0].StorePushCert(ctx, pemCert, pemKey) - for n, storage := range ms.stores[1:] { - if err := storage.StorePushCert(ctx, pemCert, pemKey); err != nil { - ms.logger.Info("method", "StorePushCert", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].StorePushCert(ctx, pemCert, pemKey) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.StorePushCert(ctx, pemCert, pemKey) + }) + return err } diff --git a/storage/allmulti/queue.go b/storage/allmulti/queue.go index b99ce8b..4378b3b 100644 --- a/storage/allmulti/queue.go +++ b/storage/allmulti/queue.go @@ -4,48 +4,39 @@ import ( "context" "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" ) func (ms *MultiAllStorage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error { - finalErr := ms.stores[0].StoreCommandReport(r, report) - for n, storage := range ms.stores[1:] { - if err := storage.StoreCommandReport(r, report); err != nil { - ms.logger.Info("method", "StoreCommandReport", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].StoreCommandReport(r, report) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.StoreCommandReport(r, report) + }) + return err } func (ms *MultiAllStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) { skipFinal, finalErr := ms.stores[0].RetrieveNextCommand(r, skipNotNow) - for n, storage := range ms.stores[1:] { - if _, err := storage.RetrieveNextCommand(r, skipNotNow); err != nil { - ms.logger.Info("method", "RetrieveNextCommand", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.RetrieveNextCommand(r, skipNotNow) + return err + }) return skipFinal, finalErr } func (ms *MultiAllStorage) ClearQueue(r *mdm.Request) error { - finalErr := ms.stores[0].ClearQueue(r) - for n, storage := range ms.stores[1:] { - if err := storage.ClearQueue(r); err != nil { - ms.logger.Info("method", "ClearQueue", "storage", n+1, "err", err) - continue - } - } - return finalErr + err := ms.stores[0].ClearQueue(r) + ms.runAndLogOthers(func(s storage.AllStorage) error { + return s.ClearQueue(r) + }) + return err } func (ms *MultiAllStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { finalMap, finalErr := ms.stores[0].EnqueueCommand(ctx, id, cmd) - for n, storage := range ms.stores[1:] { - if _, err := storage.EnqueueCommand(ctx, id, cmd); err != nil { - ms.logger.Info("method", "EnqueueCommand", "storage", n+1, "err", err) - continue - } - } + ms.runAndLogOthers(func(s storage.AllStorage) error { + _, err := s.EnqueueCommand(ctx, id, cmd) + return err + }) return finalMap, finalErr }