From 18ea1d9fb07e703986a0f4bd738866ef3aeaefea Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Mon, 20 Mar 2023 15:55:53 +0100 Subject: [PATCH] WIP --- identity/handler.go | 2 +- identity/identity.go | 6 +- identity/manager.go | 5 +- identity/pool.go | 2 +- .../sql/identity/persister_identity.go | 105 ++++++------------ 5 files changed, 44 insertions(+), 76 deletions(-) diff --git a/identity/handler.go b/identity/handler.go index c941e7e3f384..1493c7a25a99 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -495,7 +495,7 @@ func (h *Handler) patchIdentities(w http.ResponseWriter, r *http.Request, _ http res.Identities = make([]*IdentityPatchResponse, len(req.Identities)) // Array to look up the index of the identity in the identities array. indexInIdentities := make([]int, len(req.Identities)) - identities := make(Identities, 0, len(req.Identities)) + identities := make([]*Identity, 0, len(req.Identities)) for i, patch := range req.Identities { if patch.Create != nil { diff --git a/identity/identity.go b/identity/identity.go index 9694b914a0d2..98918d385056 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -402,9 +402,7 @@ func (i *Identity) Validate() error { return nil } -type Identities []*Identity - -func (i Identities) VerifiableAddresses() (res []VerifiableAddress) { +func VerifiableAddresses(i []*Identity) (res []VerifiableAddress) { res = make([]VerifiableAddress, 0, len(i)) for _, id := range i { res = append(res, id.VerifiableAddresses...) @@ -413,7 +411,7 @@ func (i Identities) VerifiableAddresses() (res []VerifiableAddress) { return res } -func (i Identities) RecoveryAddresses() (res []RecoveryAddress) { +func RecoveryAddresses(i []*Identity) (res []RecoveryAddress) { res = make([]RecoveryAddress, 0, len(i)) for _, id := range i { res = append(res, id.RecoveryAddresses...) diff --git a/identity/manager.go b/identity/manager.go index e857f40f6df1..3f4ef4ea96cd 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -88,7 +88,7 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption return m.r.PrivilegedIdentityPool().CreateIdentity(ctx, i) } -func (m *Manager) CreateIdentities(ctx context.Context, identities Identities, opts ...ManagerOption) (err error) { +func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Create") defer otelx.End(span, &err) @@ -103,7 +103,8 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities Identities, o } } - return m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities) + identities, err = m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...) + return err } func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) { diff --git a/identity/pool.go b/identity/pool.go index 08fbb6a06626..5a2689649254 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -64,7 +64,7 @@ type ( // CreateIdentities creates multiple identities. It is capable of setting credentials without encoding. Will return an error // if identity exists, backend connectivity is broken, or trait validation fails. - CreateIdentities(context.Context, Identities) error + CreateIdentities(context.Context, ...*Identity) ([]*Identity, error) // UpdateIdentity updates an identity including its confidential / privileged / protected data. UpdateIdentity(context.Context, *Identity) error diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index caef466552c8..e5bc55b314c4 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -274,12 +274,16 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, ident return nil } -func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, addresses []identity.VerifiableAddress) (err error) { +func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, identities ...*identity.Identity) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createVerifiableAddresses") defer otelx.End(span, &err) - if err := p.GetConnection(ctx).Create(addresses); err != nil { - return err + for _, id := range identities { + for i := range id.VerifiableAddresses { + if err := p.GetConnection(ctx).Create(&id.VerifiableAddresses[i]); err != nil { + return err + } + } } return nil } @@ -331,9 +335,11 @@ func updateAssociation[T interface { return nil } -func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) { - p.normalizeRecoveryAddresses(ctx, id) - p.normalizeVerifiableAddresses(ctx, id) +func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, identities ...*identity.Identity) { + for _, id := range identities { + p.normalizeRecoveryAddresses(ctx, id) + p.normalizeVerifiableAddresses(ctx, id) + } } func (p *IdentityPersister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) { @@ -370,12 +376,16 @@ func (p *IdentityPersister) normalizeRecoveryAddresses(ctx context.Context, id * } } -func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, addresses []identity.RecoveryAddress) (err error) { +func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, identities ...*identity.Identity) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses") defer otelx.End(span, &err) - if err := p.GetConnection(ctx).Create(addresses); err != nil { - return err + for _, id := range identities { + for i := range id.RecoveryAddresses { + if err := p.GetConnection(ctx).Create(&id.RecoveryAddresses[i]); err != nil { + return err + } + } } return nil @@ -396,51 +406,16 @@ func (p *IdentityPersister) CreateIdentity(ctx context.Context, ident *identity. ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity") defer otelx.End(span, &err) - ident.NID = p.NetworkID(ctx) - - if ident.SchemaID == "" { - ident.SchemaID = p.r.Config().DefaultIdentityTraitsSchemaID(ctx) + var res []*identity.Identity + res, err = p.CreateIdentities(ctx, ident) + if err == nil && len(res) > 0 { + *ident = *res[0] } - - stateChangedAt := sqlxx.NullTime(time.Now()) - ident.StateChangedAt = &stateChangedAt - if ident.State == "" { - ident.State = identity.StateActive - } - - if len(ident.Traits) == 0 { - ident.Traits = identity.Traits("{}") - } - - if err := p.InjectTraitsSchemaURL(ctx, ident); err != nil { - return err - } - - if err := p.validateIdentity(ctx, ident); err != nil { - return err - } - - return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - if err := tx.Create(ident); err != nil { - return sqlcon.HandleError(err) - } - - p.normalizeAllAddressess(ctx, ident) - - if err := p.createVerifiableAddresses(ctx, ident.VerifiableAddresses); err != nil { - return sqlcon.HandleError(err) - } - - if err := p.createRecoveryAddresses(ctx, ident.RecoveryAddresses); err != nil { - return sqlcon.HandleError(err) - } - - return sqlcon.HandleError(p.createIdentityCredentials(ctx, ident)) - }) + return err } -func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities identity.Identities) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity") +func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...*identity.Identity) (res []*identity.Identity, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentities") defer otelx.End(span, &err) for _, ident := range identities { @@ -461,33 +436,27 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ide } if err := p.InjectTraitsSchemaURL(ctx, ident); err != nil { - return err + return nil, err } if err := p.validateIdentity(ctx, ident); err != nil { - return err + return nil, err } } - return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { - identitiesCopy := make([]identity.Identity, len(identities)) - for i := range identities { - identitiesCopy[i] = *identities[i] - } - if err := tx.Create(identitiesCopy); err != nil { - return sqlcon.HandleError(err) - } - for i := range identitiesCopy { - identities[i] = &identitiesCopy[i] + return identities, p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + // TODO(hperl): optimize + for _, id := range identities { + if err := tx.Create(id); err != nil { + return sqlcon.HandleError(err) + } } - for _, ident := range identities { - p.normalizeAllAddressess(ctx, ident) - } + p.normalizeAllAddressess(ctx, identities...) eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() error { return p.createVerifiableAddresses(ctx, identities.VerifiableAddresses()) }) - eg.Go(func() error { return p.createRecoveryAddresses(ctx, identities.RecoveryAddresses()) }) + eg.Go(func() error { return p.createVerifiableAddresses(ctx, identities...) }) + eg.Go(func() error { return p.createRecoveryAddresses(ctx, identities...) }) eg.Go(func() error { return p.createIdentityCredentials(ctx, identities...) }) return sqlcon.HandleError(eg.Wait())