Skip to content

Commit

Permalink
refactor: persister query and setup more test data
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajay Kelkar committed Feb 10, 2023
1 parent a101504 commit f3020d3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 60 deletions.
43 changes: 34 additions & 9 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,21 +594,46 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
})

t.Run("case=find identity by its credentials identifier", func(t *testing.T) {
expected := passwordIdentity("", "find-identity-by-identifier@ory.sh")
expected.Traits = identity.Traits(`{}`)
var expectedIdentifiers []string
var expectedIdentities []*identity.Identity

require.NoError(t, p.CreateIdentity(ctx, expected))
createdIDs = append(createdIDs, expected.ID)
for _, c := range []identity.CredentialsType{
identity.CredentialsTypePassword,
identity.CredentialsTypeOIDC,
identity.CredentialsTypeWebAuthn,
} {
identityIdentifier := fmt.Sprintf("find-identity-by-identifier-%s@ory.sh", c)
expected := identity.NewIdentity("")
expected.SetCredentials(c, identity.Credentials{Type: c, Identifiers: []string{identityIdentifier}, Config: sqlxx.JSONRawMessage(`{}`)})

actual, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier@ory.sh")
require.NoError(t, err)
require.NoError(t, p.CreateIdentity(ctx, expected))
createdIDs = append(createdIDs, expected.ID)
expectedIdentifiers = append(expectedIdentifiers, identityIdentifier)
expectedIdentities = append(expectedIdentities, expected)
}

expected.Credentials = nil
assertEqual(t, expected, actual)
for c, ct := range []identity.CredentialsType{
identity.CredentialsTypePassword,
identity.CredentialsTypeOIDC,
identity.CredentialsTypeWebAuthn,
} {
t.Run(ct.String(), func(t *testing.T) {
actual, err := p.FindByCredentialsIdentifier(ctx, expectedIdentifiers[c])
require.NoError(t, err)

expected := expectedIdentities[c]
assert.EqualValues(t, expected.Credentials[ct].ID, actual.Credentials[ct].ID)
assert.EqualValues(t, expected.Credentials[ct].Identifiers, actual.Credentials[ct].Identifiers)

expected.Credentials = nil
actual.Credentials = nil
assertEqual(t, expected, actual)
})
}

t.Run("not if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
_, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier@ory.sh")
_, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier-password@ory.sh")
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
})
Expand Down
107 changes: 56 additions & 51 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,7 @@ func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, match strin
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsIdentifier")
defer span.End()

nid := p.NetworkID(ctx)

var find struct {
IdentityID uuid.UUID `db:"identity_id"`
}

// #nosec G201
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
ic.identity_id
FROM %s ic
INNER JOIN %s ici on ic.id = ici.identity_credential_id
WHERE ici.identifier = ?
AND ic.nid = ?
AND ici.nid = ?`,
"identity_credentials",
"identity_credential_identifiers",
),
match,
nid,
nid,
).First(&find); err != nil {
return nil, sqlcon.HandleError(err)
}

i, err := p.GetIdentity(ctx, find.IdentityID, identity.ExpandDefault)
i, err := p.findIdentityByIdentifier(ctx, nil, match)
if err != nil {
return nil, err
}
Expand All @@ -123,17 +99,32 @@ func (p *Persister) FindByCredentialsTypeAndIdentifier(ctx context.Context, ct i
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsTypeAndIdentifier")
defer span.End()

nid := p.NetworkID(ctx)
// Force case-insensitivity and trimming for identifiers
match = p.normalizeIdentifier(ct, match)

i, err := p.findIdentityByIdentifier(ctx, &ct, match)
if err != nil {
return nil, nil, err
}

creds, ok := i.GetCredentials(ct)
if !ok {
return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type \"%s\". This is a bug in the code.", ct))
}

return i.CopyWithoutCredentials(), creds, nil
}

func (p *Persister) findIdentityByIdentifier(ctx context.Context, ct *identity.CredentialsType, match string) (*identity.Identity, error) {
var find struct {
IdentityID uuid.UUID `db:"identity_id"`
}

// Force case-insensitivity and trimming for identifiers
match = p.normalizeIdentifier(ct, match)
nid := p.NetworkID(ctx)

// #nosec G201
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
if ct != nil {
// #nosec G201
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
ic.identity_id
FROM %s ic
INNER JOIN %s ict on ic.identity_credential_type_id = ict.id
Expand All @@ -142,33 +133,47 @@ WHERE ici.identifier = ?
AND ic.nid = ?
AND ici.nid = ?
AND ict.name = ?`,
"identity_credentials",
"identity_credential_types",
"identity_credential_identifiers",
),
match,
nid,
nid,
ct,
).First(&find); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, sqlcon.HandleError(err) // herodot.ErrNotFound.WithTrace(err).WithReasonf(`No identity matching credentials identifier "%s" could be found.`, match)
}
"identity_credentials",
"identity_credential_types",
"identity_credential_identifiers",
),
match,
nid,
nid,
ct,
).First(&find); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, sqlcon.HandleError(err) // herodot.ErrNotFound.WithTrace(err).WithReasonf(`No identity matching credentials identifier "%s" could be found.`, match)
}

return nil, nil, sqlcon.HandleError(err)
return nil, sqlcon.HandleError(err)
}
} else {
// #nosec G201
if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT
ic.identity_id
FROM %s ic
INNER JOIN %s ici on ic.id = ici.identity_credential_id
WHERE ici.identifier = ?
AND ic.nid = ?
AND ici.nid = ?`,
"identity_credentials",
"identity_credential_identifiers",
),
match,
nid,
nid,
).First(&find); err != nil {
return nil, sqlcon.HandleError(err)
}
}

i, err := p.GetIdentityConfidential(ctx, find.IdentityID)
i, err := p.GetIdentity(ctx, find.IdentityID, identity.ExpandEverything)
if err != nil {
return nil, nil, err
}

creds, ok := i.GetCredentials(ct)
if !ok {
return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type \"%s\". This is a bug in the code.", ct))
return nil, err
}

return i.CopyWithoutCredentials(), creds, nil
return i, nil
}

func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (*identity.CredentialsTypeTable, error) {
Expand Down

0 comments on commit f3020d3

Please sign in to comment.