Skip to content

Commit

Permalink
fix: do not invalidate recovery addr on update (#2699)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Sep 2, 2022
1 parent a0d2bfb commit 1689bb9
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 28 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/ory/kratos

go 1.17
go 1.18

replace (
github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3
Expand Down
6 changes: 6 additions & 0 deletions identity/identity_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package identity

import (
"context"
"fmt"
"time"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -55,6 +56,11 @@ func (a RecoveryAddress) ValidateNID() error {
return nil
}

// Hash returns a unique string representation for the recovery address.
func (a RecoveryAddress) Hash() string {
return fmt.Sprintf("%v|%v|%v|%v", a.Value, a.Via, a.IdentityID, a.NID)
}

func NewRecoveryEmailAddress(
value string,
identity uuid.UUID,
Expand Down
41 changes: 40 additions & 1 deletion identity/identity_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package identity

import (
"testing"
"time"

"github.com/gofrs/uuid"

"github.com/stretchr/testify/assert"

"github.com/ory/kratos/x"
Expand All @@ -19,3 +19,42 @@ func TestNewRecoveryEmailAddress(t *testing.T) {
assert.Equal(t, iid, a.IdentityID)
assert.Equal(t, uuid.Nil, a.ID)
}

// TestRecoveryAddress_Hash tests that the hash considers all fields that are
// written to the database (ignoring some well-known fields like the ID or
// timestamps).
func TestRecoveryAddress_Hash(t *testing.T) {
cases := []struct {
name string
a RecoveryAddress
}{
{
name: "full fields",
a: RecoveryAddress{
ID: x.NewUUID(),
Value: "foo@bar.me",
Via: AddressTypeEmail,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
IdentityID: x.NewUUID(),
NID: x.NewUUID(),
},
}, {
name: "empty fields",
a: RecoveryAddress{},
}, {
name: "constructor",
a: *NewRecoveryEmailAddress("foo@ory.sh", x.NewUUID()),
},
}

for _, tc := range cases {
t.Run("case="+tc.name, func(t *testing.T) {
assert.Equal(t,
reflectiveHash(tc.a),
tc.a.Hash(),
)
})
}

}
6 changes: 6 additions & 0 deletions identity/identity_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package identity

import (
"context"
"fmt"
"time"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -129,3 +130,8 @@ func (a VerifiableAddress) GetNID() uuid.UUID {
func (a VerifiableAddress) ValidateNID() error {
return nil
}

// Hash returns a unique string representation for the recovery address.
func (a VerifiableAddress) Hash() string {
return fmt.Sprintf("%v|%v|%v|%v|%v|%v", a.Value, a.Verified, a.Via, a.Status, a.IdentityID, a.NID)
}
76 changes: 76 additions & 0 deletions identity/identity_verification_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package identity

import (
"fmt"
"reflect"
"strings"
"testing"
"time"

"github.com/gofrs/uuid"

Expand All @@ -25,3 +29,75 @@ func TestNewVerifiableEmailAddress(t *testing.T) {
assert.Equal(t, iid, a.IdentityID)
assert.Equal(t, uuid.Nil, a.ID)
}

var tagsIgnoredForHashing = map[string]struct{}{
"id": {},
"created_at": {},
"updated_at": {},
"verified_at": {},
}

func reflectiveHash(record any) string {
var (
val = reflect.ValueOf(record)
typ = reflect.TypeOf(record)
values = []string{}
)
for i := 0; i < val.NumField(); i++ {
dbTag, ok := typ.Field(i).Tag.Lookup("db")
if !ok {
continue
}
if _, ignore := tagsIgnoredForHashing[dbTag]; ignore {
continue
}
if !val.Field(i).CanInterface() {
continue
}
values = append(values, fmt.Sprintf("%v", val.Field(i).Interface()))
}
return strings.Join(values, "|")
}

// TestVerifiableAddress_Hash tests that the hash considers all fields that are
// written to the database (ignoring some well-known fields like the ID or
// timestamps).
func TestVerifiableAddress_Hash(t *testing.T) {
now := sqlxx.NullTime(time.Now())
cases := []struct {
name string
a VerifiableAddress
}{
{
name: "full fields",
a: VerifiableAddress{
ID: x.NewUUID(),
Value: "foo@bar.me",
Verified: false,
Via: AddressTypeEmail,
Status: VerifiableAddressStatusPending,
VerifiedAt: &now,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
IdentityID: x.NewUUID(),
NID: x.NewUUID(),
},
}, {
name: "empty fields",
a: VerifiableAddress{},
}, {
name: "constructor",
a: *NewVerifiableEmailAddress("foo@ory.sh", x.NewUUID()),
},
}

for _, tc := range cases {
t.Run("case="+tc.name, func(t *testing.T) {
assert.Equal(t,
reflectiveHash(tc.a),
tc.a.Hash(),
)
})
}

}
5 changes: 3 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 80 additions & 21 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,83 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I
defer span.End()

for k := range i.VerifiableAddresses {
i.VerifiableAddresses[k].IdentityID = i.ID
i.VerifiableAddresses[k].NID = p.NetworkID(ctx)
i.VerifiableAddresses[k].Value = stringToLowerTrim(i.VerifiableAddresses[k].Value)
if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil {
return err
}
}
return nil
}

func updateAssociation[T interface {
Hash() string
}](ctx context.Context, p *Persister, i *identity.Identity, inID []T) error {
var inDB []T
if err := p.GetConnection(ctx).
Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).
Order("id ASC").
All(&inDB); err != nil {

return sqlcon.HandleError(err)
}

newAssocs := make(map[string]*T)
oldAssocs := make(map[string]*T)
for i, a := range inID {
newAssocs[a.Hash()] = &inID[i]
}
for i, a := range inDB {
oldAssocs[a.Hash()] = &inDB[i]
}

// Subtle: we delete the old associations from the DB first, because else
// they could cause UNIQUE constraints to fail on insert.
for h, a := range oldAssocs {
if _, found := newAssocs[h]; found {
newAssocs[h] = nil // Ignore associations that are already in the db.
} else {
if err := p.GetConnection(ctx).Destroy(a); err != nil {
return sqlcon.HandleError(err)
}
}
}

for _, a := range newAssocs {
if a != nil {
if err := p.GetConnection(ctx).Create(a); err != nil {
return sqlcon.HandleError(err)
}
}
}

return nil
}

func (p *Persister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
p.normalizeRecoveryAddresses(ctx, id)
p.normalizeVerifiableAddresses(ctx, id)
}

func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
for k := range id.VerifiableAddresses {
id.VerifiableAddresses[k].IdentityID = id.ID
id.VerifiableAddresses[k].NID = p.NetworkID(ctx)
id.VerifiableAddresses[k].Value = stringToLowerTrim(id.VerifiableAddresses[k].Value)
}
}

func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
for k := range id.RecoveryAddresses {
id.RecoveryAddresses[k].IdentityID = id.ID
id.RecoveryAddresses[k].NID = p.NetworkID(ctx)
id.RecoveryAddresses[k].Value = stringToLowerTrim(id.RecoveryAddresses[k].Value)
}
}

func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses")
defer span.End()

for k := range i.RecoveryAddresses {
i.RecoveryAddresses[k].IdentityID = i.ID
i.RecoveryAddresses[k].NID = p.NetworkID(ctx)
i.RecoveryAddresses[k].Value = stringToLowerTrim(i.RecoveryAddresses[k].Value)
if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil {
return err
}
Expand Down Expand Up @@ -285,6 +344,8 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
return sqlcon.HandleError(err)
}

p.normalizeAllAddressess(ctx, i)

if err := p.createVerifiableAddresses(ctx, i); err != nil {
return sqlcon.HandleError(err)
}
Expand Down Expand Up @@ -350,27 +411,25 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er
return sql.ErrNoRows
}

for _, tn := range []string{
new(identity.Credentials).TableName(ctx),
new(identity.VerifiableAddress).TableName(ctx),
new(identity.RecoveryAddress).TableName(ctx),
} {
/* #nosec G201 TableName is static */
if err := tx.RawQuery(fmt.Sprintf(
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`, tn), i.ID, p.NetworkID(ctx)).Exec(); err != nil {
return err
}
p.normalizeAllAddressess(ctx, i)
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
return err
}

if err := p.update(WithTransaction(ctx, tx), i); err != nil {
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
return err
}

if err := p.createVerifiableAddresses(ctx, i); err != nil {
return err
/* #nosec G201 TableName is static */
if err := tx.RawQuery(
fmt.Sprintf(
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
new(identity.Credentials).TableName(ctx)),
i.ID, p.NetworkID(ctx)).Exec(); err != nil {

return sqlcon.HandleError(err)
}

if err := p.createRecoveryAddresses(ctx, i); err != nil {
if err := p.update(WithTransaction(ctx, tx), i); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func TestPersister_Transaction(t *testing.T) {
Traits: ri.Traits(`{}`),
}
errMessage := "failing because why not"
err := p.Transaction(context.Background(), func(ctx context.Context, connection *pop.Connection) error {
err := p.Transaction(context.Background(), func(_ context.Context, connection *pop.Connection) error {
require.NoError(t, connection.Create(i))
return errors.Errorf(errMessage)
})
Expand Down
27 changes: 25 additions & 2 deletions selfservice/strategy/link/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,31 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
assert.NotEqual(t, expected.Token, actual.Token)
assert.EqualValues(t, expected.FlowID, actual.FlowID)

_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.Error(t, err)
t.Run("double spend", func(t *testing.T) {
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.Error(t, err)
})
})

t.Run("case=update to identity should not invalidate token", func(t *testing.T) {
expected, f := newRecoveryToken(t, "some-user@ory.sh")

require.NoError(t, p.CreateRecoveryToken(ctx, expected))
id, err := p.GetIdentity(ctx, expected.IdentityID)
require.NoError(t, err)
require.NoError(t, p.UpdateIdentity(ctx, id))

actual, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.NoError(t, err)
assert.Equal(t, nid, actual.NID)
assert.Equal(t, expected.IdentityID, actual.IdentityID)
assert.NotEqual(t, expected.Token, actual.Token)
assert.EqualValues(t, expected.FlowID, actual.FlowID)

t.Run("double spend", func(t *testing.T) {
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.Error(t, err)
})
})

})
Expand Down

0 comments on commit 1689bb9

Please sign in to comment.