Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Mar 20, 2023
1 parent 5dc0184 commit 18ea1d9
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 76 deletions.
2 changes: 1 addition & 1 deletion identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ func (h *Handler) patchIdentities(w http.ResponseWriter, r *http.Request, _ http
res.Identities = make([]*IdentityPatchResponse, len(req.Identities))
// Array to look up the index of the identity in the identities array.
indexInIdentities := make([]int, len(req.Identities))
identities := make(Identities, 0, len(req.Identities))
identities := make([]*Identity, 0, len(req.Identities))

for i, patch := range req.Identities {
if patch.Create != nil {
Expand Down
6 changes: 2 additions & 4 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,7 @@ func (i *Identity) Validate() error {
return nil
}

type Identities []*Identity

func (i Identities) VerifiableAddresses() (res []VerifiableAddress) {
func VerifiableAddresses(i []*Identity) (res []VerifiableAddress) {
res = make([]VerifiableAddress, 0, len(i))
for _, id := range i {
res = append(res, id.VerifiableAddresses...)
Expand All @@ -413,7 +411,7 @@ func (i Identities) VerifiableAddresses() (res []VerifiableAddress) {
return res
}

func (i Identities) RecoveryAddresses() (res []RecoveryAddress) {
func RecoveryAddresses(i []*Identity) (res []RecoveryAddress) {
res = make([]RecoveryAddress, 0, len(i))
for _, id := range i {
res = append(res, id.RecoveryAddresses...)
Expand Down
5 changes: 3 additions & 2 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
return m.r.PrivilegedIdentityPool().CreateIdentity(ctx, i)
}

func (m *Manager) CreateIdentities(ctx context.Context, identities Identities, opts ...ManagerOption) (err error) {
func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Create")
defer otelx.End(span, &err)

Expand All @@ -103,7 +103,8 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities Identities, o
}
}

return m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities)
identities, err = m.r.PrivilegedIdentityPool().CreateIdentities(ctx, identities...)
return err
}

func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) {
Expand Down
2 changes: 1 addition & 1 deletion identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type (

// CreateIdentities creates multiple identities. It is capable of setting credentials without encoding. Will return an error
// if identity exists, backend connectivity is broken, or trait validation fails.
CreateIdentities(context.Context, Identities) error
CreateIdentities(context.Context, ...*Identity) ([]*Identity, error)

// UpdateIdentity updates an identity including its confidential / privileged / protected data.
UpdateIdentity(context.Context, *Identity) error
Expand Down
105 changes: 37 additions & 68 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,16 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, ident
return nil
}

func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, addresses []identity.VerifiableAddress) (err error) {
func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, identities ...*identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createVerifiableAddresses")
defer otelx.End(span, &err)

if err := p.GetConnection(ctx).Create(addresses); err != nil {
return err
for _, id := range identities {
for i := range id.VerifiableAddresses {
if err := p.GetConnection(ctx).Create(&id.VerifiableAddresses[i]); err != nil {
return err
}
}
}
return nil
}
Expand Down Expand Up @@ -331,9 +335,11 @@ func updateAssociation[T interface {
return nil
}

func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
p.normalizeRecoveryAddresses(ctx, id)
p.normalizeVerifiableAddresses(ctx, id)
func (p *IdentityPersister) normalizeAllAddressess(ctx context.Context, identities ...*identity.Identity) {
for _, id := range identities {
p.normalizeRecoveryAddresses(ctx, id)
p.normalizeVerifiableAddresses(ctx, id)
}
}

func (p *IdentityPersister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
Expand Down Expand Up @@ -370,12 +376,16 @@ func (p *IdentityPersister) normalizeRecoveryAddresses(ctx context.Context, id *
}
}

func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, addresses []identity.RecoveryAddress) (err error) {
func (p *IdentityPersister) createRecoveryAddresses(ctx context.Context, identities ...*identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses")
defer otelx.End(span, &err)

if err := p.GetConnection(ctx).Create(addresses); err != nil {
return err
for _, id := range identities {
for i := range id.RecoveryAddresses {
if err := p.GetConnection(ctx).Create(&id.RecoveryAddresses[i]); err != nil {
return err
}
}
}

return nil
Expand All @@ -396,51 +406,16 @@ func (p *IdentityPersister) CreateIdentity(ctx context.Context, ident *identity.
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity")
defer otelx.End(span, &err)

ident.NID = p.NetworkID(ctx)

if ident.SchemaID == "" {
ident.SchemaID = p.r.Config().DefaultIdentityTraitsSchemaID(ctx)
var res []*identity.Identity
res, err = p.CreateIdentities(ctx, ident)
if err == nil && len(res) > 0 {
*ident = *res[0]
}

stateChangedAt := sqlxx.NullTime(time.Now())
ident.StateChangedAt = &stateChangedAt
if ident.State == "" {
ident.State = identity.StateActive
}

if len(ident.Traits) == 0 {
ident.Traits = identity.Traits("{}")
}

if err := p.InjectTraitsSchemaURL(ctx, ident); err != nil {
return err
}

if err := p.validateIdentity(ctx, ident); err != nil {
return err
}

return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
if err := tx.Create(ident); err != nil {
return sqlcon.HandleError(err)
}

p.normalizeAllAddressess(ctx, ident)

if err := p.createVerifiableAddresses(ctx, ident.VerifiableAddresses); err != nil {
return sqlcon.HandleError(err)
}

if err := p.createRecoveryAddresses(ctx, ident.RecoveryAddresses); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(p.createIdentityCredentials(ctx, ident))
})
return err
}

func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities identity.Identities) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentity")
func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...*identity.Identity) (res []*identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateIdentities")
defer otelx.End(span, &err)

for _, ident := range identities {
Expand All @@ -461,33 +436,27 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ide
}

if err := p.InjectTraitsSchemaURL(ctx, ident); err != nil {
return err
return nil, err
}

if err := p.validateIdentity(ctx, ident); err != nil {
return err
return nil, err
}
}

return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
identitiesCopy := make([]identity.Identity, len(identities))
for i := range identities {
identitiesCopy[i] = *identities[i]
}
if err := tx.Create(identitiesCopy); err != nil {
return sqlcon.HandleError(err)
}
for i := range identitiesCopy {
identities[i] = &identitiesCopy[i]
return identities, p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
// TODO(hperl): optimize
for _, id := range identities {
if err := tx.Create(id); err != nil {
return sqlcon.HandleError(err)
}
}

for _, ident := range identities {
p.normalizeAllAddressess(ctx, ident)
}
p.normalizeAllAddressess(ctx, identities...)

eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { return p.createVerifiableAddresses(ctx, identities.VerifiableAddresses()) })
eg.Go(func() error { return p.createRecoveryAddresses(ctx, identities.RecoveryAddresses()) })
eg.Go(func() error { return p.createVerifiableAddresses(ctx, identities...) })
eg.Go(func() error { return p.createRecoveryAddresses(ctx, identities...) })
eg.Go(func() error { return p.createIdentityCredentials(ctx, identities...) })

return sqlcon.HandleError(eg.Wait())
Expand Down

0 comments on commit 18ea1d9

Please sign in to comment.